mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Remove old code
This commit is contained in:
108
bin/train.py
108
bin/train.py
@@ -49,8 +49,8 @@ torch.backends.cudnn.benchmark = False
|
||||
def plot_timestep_distributions(
|
||||
train_dset,
|
||||
timesteps: int,
|
||||
shift_angles_zero_twopi: bool,
|
||||
plots_folder: Path,
|
||||
shift_angles_zero_twopi: bool = False,
|
||||
n_intervals: int = 11,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -104,21 +104,21 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
|
||||
|
||||
def get_train_valid_test_sets(
|
||||
angles_definitions: Literal[
|
||||
"rosetta", "canonical", "canonical_angles_only", "canonical_dihedrals_only",
|
||||
"rosetta",
|
||||
"canonical",
|
||||
"canonical_angles_only",
|
||||
"canonical_dihedrals_only",
|
||||
] = "rosetta", # Keep this default at rosetta for compatibility
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0,
|
||||
seq_trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
|
||||
timesteps: int = 250,
|
||||
variance_schedule: SCHEDULES = "linear",
|
||||
noise_prior: Literal["gaussian", "uniform"] = "gaussian",
|
||||
shift_to_zero_twopi: bool = False,
|
||||
zero_center: bool = False,
|
||||
var_scale: float = np.pi,
|
||||
toy: Union[int, bool] = False,
|
||||
exhaustive_t: bool = False,
|
||||
syn_noiser: str = "",
|
||||
single_dist_debug: bool = False,
|
||||
single_angle_debug: int = -1, # Noise and return a single angle. -1 to disable, 1-3 for omega/theta/phi
|
||||
single_time_debug: bool = False, # Noise and return a single time
|
||||
) -> Tuple[Dataset, Dataset, Dataset]:
|
||||
@@ -145,7 +145,6 @@ def get_train_valid_test_sets(
|
||||
pad=max_seq_len,
|
||||
min_length=min_seq_len,
|
||||
trim_strategy=seq_trim_strategy,
|
||||
shift_to_zero_twopi=shift_to_zero_twopi,
|
||||
zero_center=zero_center,
|
||||
toy=toy,
|
||||
)
|
||||
@@ -163,11 +162,8 @@ def get_train_valid_test_sets(
|
||||
dset_noiser_class = datasets.SynNoisedByPositionDataset
|
||||
else:
|
||||
raise ValueError(f"Unknown synthetic noiser {syn_noiser}")
|
||||
elif noise_prior == "gaussian":
|
||||
if single_dist_debug:
|
||||
logging.warning("Using single dist debug")
|
||||
dset_noiser_class = datasets.SingleNoisedBondDistanceDataset
|
||||
elif single_angle_debug > 0:
|
||||
else:
|
||||
if single_angle_debug > 0:
|
||||
logging.warning("Using single angle noise!")
|
||||
dset_noiser_class = functools.partial(
|
||||
datasets.SingleNoisedAngleDataset, ft_idx=single_angle_debug
|
||||
@@ -177,10 +173,6 @@ def get_train_valid_test_sets(
|
||||
dset_noiser_class = datasets.SingleNoisedAngleAndTimeDataset
|
||||
else:
|
||||
dset_noiser_class = datasets.NoisedAnglesDataset
|
||||
elif noise_prior == "uniform":
|
||||
dset_noiser_class = datasets.GaussianDistUniformAnglesNoisedDataset
|
||||
else:
|
||||
raise ValueError(f"Unrecognized noise prior: {noise_prior}")
|
||||
|
||||
logging.info(f"Using {dset_noiser_class} for noise")
|
||||
noised_dsets = [
|
||||
@@ -190,7 +182,6 @@ def get_train_valid_test_sets(
|
||||
timesteps=timesteps,
|
||||
exhaustive_t=(i != 0) and exhaustive_t,
|
||||
beta_schedule=variance_schedule,
|
||||
shift_to_zero_twopi=shift_to_zero_twopi,
|
||||
nonangular_variance=1.0,
|
||||
angular_variance=var_scale,
|
||||
)
|
||||
@@ -291,16 +282,11 @@ def train(
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0, # 0 means no filtering based on min sequence length
|
||||
trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
|
||||
shift_angles_zero_twopi: bool = False,
|
||||
zero_center: bool = False,
|
||||
noise_prior: Literal["gaussian", "uniform"] = "gaussian", # Uniform not tested
|
||||
timesteps: int = 250,
|
||||
variance_schedule: SCHEDULES = "linear", # cosine better on single angle toy test
|
||||
variance_scale: float = 1.0,
|
||||
# Related to model architecture
|
||||
implementation: Literal[
|
||||
"pytorch_encoder", "huggingface_encoder"
|
||||
] = "huggingface_encoder",
|
||||
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
|
||||
num_hidden_layers: int = 12, # Default 12
|
||||
hidden_size: int = 384, # Default 768
|
||||
@@ -329,7 +315,6 @@ def train(
|
||||
subset: Union[bool, int] = False, # Subset to n training examples
|
||||
exhaustive_validation_t: bool = False, # Exhaustively enumerate t for validation/test
|
||||
syn_noiser: str = "", # If specified, use a synthetic noiser
|
||||
single_dist_debug: bool = False, # Debug on distance (no periodicity)
|
||||
single_angle_debug: int = -1, # Noise and return a single angle, choose [1, 2, 3] or -1 to disable
|
||||
single_timestep_debug: bool = False, # Noise and return a single timestep
|
||||
cpu_only: bool = False,
|
||||
@@ -353,14 +338,11 @@ def train(
|
||||
seq_trim_strategy=trim_strategy,
|
||||
timesteps=timesteps,
|
||||
variance_schedule=variance_schedule,
|
||||
noise_prior=noise_prior,
|
||||
shift_to_zero_twopi=shift_angles_zero_twopi,
|
||||
zero_center=zero_center,
|
||||
var_scale=variance_scale,
|
||||
toy=subset,
|
||||
syn_noiser=syn_noiser,
|
||||
exhaustive_t=exhaustive_validation_t,
|
||||
single_dist_debug=single_dist_debug,
|
||||
single_angle_debug=single_angle_debug,
|
||||
single_time_debug=single_timestep_debug,
|
||||
)
|
||||
@@ -393,7 +375,6 @@ def train(
|
||||
if (
|
||||
single_angle_debug < 0
|
||||
and not single_timestep_debug
|
||||
and not single_dist_debug
|
||||
and not syn_noiser
|
||||
and not dryrun
|
||||
):
|
||||
@@ -401,7 +382,6 @@ def train(
|
||||
plot_timestep_distributions(
|
||||
dsets[0],
|
||||
timesteps=timesteps,
|
||||
shift_angles_zero_twopi=shift_angles_zero_twopi,
|
||||
plots_folder=plots_folder,
|
||||
)
|
||||
|
||||
@@ -411,8 +391,6 @@ def train(
|
||||
loss_fn = loss
|
||||
if single_angle_debug > 0 or single_timestep_debug or syn_noiser:
|
||||
loss_fn = functools.partial(losses.radian_smooth_l1_loss, beta=0.1 * np.pi)
|
||||
elif single_dist_debug:
|
||||
loss_fn = F.smooth_l1_loss
|
||||
logging.info(f"Using loss function: {loss_fn}")
|
||||
|
||||
# Shape of the input is (batch_size, timesteps, features)
|
||||
@@ -420,41 +398,38 @@ def train(
|
||||
model_n_inputs = sample_input.shape[-1]
|
||||
logging.info(f"Auto detected {model_n_inputs} inputs")
|
||||
|
||||
if implementation == "huggingface_encoder":
|
||||
logging.info("Using HuggingFace encoder implementation")
|
||||
cfg = BertConfig(
|
||||
max_position_embeddings=max_seq_len,
|
||||
num_attention_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
position_embedding_type=position_embedding_type,
|
||||
hidden_dropout_prob=dropout_p,
|
||||
attention_probs_dropout_prob=dropout_p,
|
||||
use_cache=False,
|
||||
)
|
||||
# ft_is_angular from the clean datasets angularity definition
|
||||
model = modelling.BertForDiffusion(
|
||||
cfg,
|
||||
time_encoding=time_encoding,
|
||||
decoder=decoder,
|
||||
ft_is_angular=dsets[0].dset.feature_is_angular["angles"],
|
||||
ft_names=dsets[0].dset.feature_names["angles"],
|
||||
lr=lr,
|
||||
loss=loss_fn,
|
||||
l2=l2_norm,
|
||||
l1=l1_norm,
|
||||
circle_reg=circle_reg,
|
||||
epochs=max_epochs,
|
||||
steps_per_epoch=len(train_dataloader),
|
||||
lr_scheduler=lr_scheduler,
|
||||
write_preds_to_dir=results_folder / "valid_preds"
|
||||
if write_valid_preds
|
||||
else None,
|
||||
)
|
||||
cfg.save_pretrained(results_folder)
|
||||
else:
|
||||
raise ValueError(f"Unknown implementation: {implementation}")
|
||||
logging.info("Using HuggingFace encoder implementation")
|
||||
cfg = BertConfig(
|
||||
max_position_embeddings=max_seq_len,
|
||||
num_attention_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
position_embedding_type=position_embedding_type,
|
||||
hidden_dropout_prob=dropout_p,
|
||||
attention_probs_dropout_prob=dropout_p,
|
||||
use_cache=False,
|
||||
)
|
||||
# ft_is_angular from the clean datasets angularity definition
|
||||
model = modelling.BertForDiffusion(
|
||||
cfg,
|
||||
time_encoding=time_encoding,
|
||||
decoder=decoder,
|
||||
ft_is_angular=dsets[0].dset.feature_is_angular["angles"],
|
||||
ft_names=dsets[0].dset.feature_names["angles"],
|
||||
lr=lr,
|
||||
loss=loss_fn,
|
||||
l2=l2_norm,
|
||||
l1=l1_norm,
|
||||
circle_reg=circle_reg,
|
||||
epochs=max_epochs,
|
||||
steps_per_epoch=len(train_dataloader),
|
||||
lr_scheduler=lr_scheduler,
|
||||
write_preds_to_dir=results_folder / "valid_preds"
|
||||
if write_valid_preds
|
||||
else None,
|
||||
)
|
||||
cfg.save_pretrained(results_folder)
|
||||
|
||||
callbacks = build_callbacks(
|
||||
outdir=results_folder, early_stop_patience=early_stop_patience, swa=use_swa
|
||||
@@ -506,7 +481,8 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
Build CLI parser
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
usage=__doc__,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# https://stackoverflow.com/questions/4480075/argparse-optional-positional-arguments
|
||||
@@ -533,7 +509,6 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Use a toy dataset of n items rather than full dataset",
|
||||
)
|
||||
parser.add_argument("--debug_dist", action="store_true", help="Debug distances")
|
||||
parser.add_argument(
|
||||
"--debug_single_time",
|
||||
action="store_true",
|
||||
@@ -562,7 +537,6 @@ def main():
|
||||
{
|
||||
"results_dir": args.outdir,
|
||||
"subset": args.toy,
|
||||
"single_dist_debug": args.debug_dist,
|
||||
"single_timestep_debug": args.debug_single_time,
|
||||
"cpu_only": args.cpu,
|
||||
"ngpu": args.ngpu,
|
||||
|
||||
Reference in New Issue
Block a user