diff --git a/bin/train.py b/bin/train.py index 5f7724a..1bf4052 100644 --- a/bin/train.py +++ b/bin/train.py @@ -49,8 +49,8 @@ torch.backends.cudnn.benchmark = False def plot_timestep_distributions( train_dset, timesteps: int, - shift_angles_zero_twopi: bool, plots_folder: Path, + shift_angles_zero_twopi: bool = False, n_intervals: int = 11, ) -> None: """ @@ -104,21 +104,21 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None: def get_train_valid_test_sets( angles_definitions: Literal[ - "rosetta", "canonical", "canonical_angles_only", "canonical_dihedrals_only", + "rosetta", + "canonical", + "canonical_angles_only", + "canonical_dihedrals_only", ] = "rosetta", # Keep this default at rosetta for compatibility max_seq_len: int = 512, min_seq_len: int = 0, seq_trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign", timesteps: int = 250, variance_schedule: SCHEDULES = "linear", - noise_prior: Literal["gaussian", "uniform"] = "gaussian", - shift_to_zero_twopi: bool = False, zero_center: bool = False, var_scale: float = np.pi, toy: Union[int, bool] = False, exhaustive_t: bool = False, syn_noiser: str = "", - single_dist_debug: bool = False, single_angle_debug: int = -1, # Noise and return a single angle. -1 to disable, 1-3 for omega/theta/phi single_time_debug: bool = False, # Noise and return a single time ) -> Tuple[Dataset, Dataset, Dataset]: @@ -145,7 +145,6 @@ def get_train_valid_test_sets( pad=max_seq_len, min_length=min_seq_len, trim_strategy=seq_trim_strategy, - shift_to_zero_twopi=shift_to_zero_twopi, zero_center=zero_center, toy=toy, ) @@ -163,11 +162,8 @@ def get_train_valid_test_sets( dset_noiser_class = datasets.SynNoisedByPositionDataset else: raise ValueError(f"Unknown synthetic noiser {syn_noiser}") - elif noise_prior == "gaussian": - if single_dist_debug: - logging.warning("Using single dist debug") - dset_noiser_class = datasets.SingleNoisedBondDistanceDataset - elif single_angle_debug > 0: + else: + if single_angle_debug > 0: logging.warning("Using single angle noise!") dset_noiser_class = functools.partial( datasets.SingleNoisedAngleDataset, ft_idx=single_angle_debug @@ -177,10 +173,6 @@ def get_train_valid_test_sets( dset_noiser_class = datasets.SingleNoisedAngleAndTimeDataset else: dset_noiser_class = datasets.NoisedAnglesDataset - elif noise_prior == "uniform": - dset_noiser_class = datasets.GaussianDistUniformAnglesNoisedDataset - else: - raise ValueError(f"Unrecognized noise prior: {noise_prior}") logging.info(f"Using {dset_noiser_class} for noise") noised_dsets = [ @@ -190,7 +182,6 @@ def get_train_valid_test_sets( timesteps=timesteps, exhaustive_t=(i != 0) and exhaustive_t, beta_schedule=variance_schedule, - shift_to_zero_twopi=shift_to_zero_twopi, nonangular_variance=1.0, angular_variance=var_scale, ) @@ -291,16 +282,11 @@ def train( max_seq_len: int = 512, min_seq_len: int = 0, # 0 means no filtering based on min sequence length trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign", - shift_angles_zero_twopi: bool = False, zero_center: bool = False, - noise_prior: Literal["gaussian", "uniform"] = "gaussian", # Uniform not tested timesteps: int = 250, variance_schedule: SCHEDULES = "linear", # cosine better on single angle toy test variance_scale: float = 1.0, # Related to model architecture - implementation: Literal[ - "pytorch_encoder", "huggingface_encoder" - ] = "huggingface_encoder", time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier", num_hidden_layers: int = 12, # Default 12 hidden_size: int = 384, # Default 768 @@ -329,7 +315,6 @@ def train( subset: Union[bool, int] = False, # Subset to n training examples exhaustive_validation_t: bool = False, # Exhaustively enumerate t for validation/test syn_noiser: str = "", # If specified, use a synthetic noiser - single_dist_debug: bool = False, # Debug on distance (no periodicity) single_angle_debug: int = -1, # Noise and return a single angle, choose [1, 2, 3] or -1 to disable single_timestep_debug: bool = False, # Noise and return a single timestep cpu_only: bool = False, @@ -353,14 +338,11 @@ def train( seq_trim_strategy=trim_strategy, timesteps=timesteps, variance_schedule=variance_schedule, - noise_prior=noise_prior, - shift_to_zero_twopi=shift_angles_zero_twopi, zero_center=zero_center, var_scale=variance_scale, toy=subset, syn_noiser=syn_noiser, exhaustive_t=exhaustive_validation_t, - single_dist_debug=single_dist_debug, single_angle_debug=single_angle_debug, single_time_debug=single_timestep_debug, ) @@ -393,7 +375,6 @@ def train( if ( single_angle_debug < 0 and not single_timestep_debug - and not single_dist_debug and not syn_noiser and not dryrun ): @@ -401,7 +382,6 @@ def train( plot_timestep_distributions( dsets[0], timesteps=timesteps, - shift_angles_zero_twopi=shift_angles_zero_twopi, plots_folder=plots_folder, ) @@ -411,8 +391,6 @@ def train( loss_fn = loss if single_angle_debug > 0 or single_timestep_debug or syn_noiser: loss_fn = functools.partial(losses.radian_smooth_l1_loss, beta=0.1 * np.pi) - elif single_dist_debug: - loss_fn = F.smooth_l1_loss logging.info(f"Using loss function: {loss_fn}") # Shape of the input is (batch_size, timesteps, features) @@ -420,41 +398,38 @@ def train( model_n_inputs = sample_input.shape[-1] logging.info(f"Auto detected {model_n_inputs} inputs") - if implementation == "huggingface_encoder": - logging.info("Using HuggingFace encoder implementation") - cfg = BertConfig( - max_position_embeddings=max_seq_len, - num_attention_heads=num_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_hidden_layers=num_hidden_layers, - position_embedding_type=position_embedding_type, - hidden_dropout_prob=dropout_p, - attention_probs_dropout_prob=dropout_p, - use_cache=False, - ) - # ft_is_angular from the clean datasets angularity definition - model = modelling.BertForDiffusion( - cfg, - time_encoding=time_encoding, - decoder=decoder, - ft_is_angular=dsets[0].dset.feature_is_angular["angles"], - ft_names=dsets[0].dset.feature_names["angles"], - lr=lr, - loss=loss_fn, - l2=l2_norm, - l1=l1_norm, - circle_reg=circle_reg, - epochs=max_epochs, - steps_per_epoch=len(train_dataloader), - lr_scheduler=lr_scheduler, - write_preds_to_dir=results_folder / "valid_preds" - if write_valid_preds - else None, - ) - cfg.save_pretrained(results_folder) - else: - raise ValueError(f"Unknown implementation: {implementation}") + logging.info("Using HuggingFace encoder implementation") + cfg = BertConfig( + max_position_embeddings=max_seq_len, + num_attention_heads=num_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + position_embedding_type=position_embedding_type, + hidden_dropout_prob=dropout_p, + attention_probs_dropout_prob=dropout_p, + use_cache=False, + ) + # ft_is_angular from the clean datasets angularity definition + model = modelling.BertForDiffusion( + cfg, + time_encoding=time_encoding, + decoder=decoder, + ft_is_angular=dsets[0].dset.feature_is_angular["angles"], + ft_names=dsets[0].dset.feature_names["angles"], + lr=lr, + loss=loss_fn, + l2=l2_norm, + l1=l1_norm, + circle_reg=circle_reg, + epochs=max_epochs, + steps_per_epoch=len(train_dataloader), + lr_scheduler=lr_scheduler, + write_preds_to_dir=results_folder / "valid_preds" + if write_valid_preds + else None, + ) + cfg.save_pretrained(results_folder) callbacks = build_callbacks( outdir=results_folder, early_stop_patience=early_stop_patience, swa=use_swa @@ -506,7 +481,8 @@ def build_parser() -> argparse.ArgumentParser: Build CLI parser """ parser = argparse.ArgumentParser( - usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter, + usage=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # https://stackoverflow.com/questions/4480075/argparse-optional-positional-arguments @@ -533,7 +509,6 @@ def build_parser() -> argparse.ArgumentParser: default=None, help="Use a toy dataset of n items rather than full dataset", ) - parser.add_argument("--debug_dist", action="store_true", help="Debug distances") parser.add_argument( "--debug_single_time", action="store_true", @@ -562,7 +537,6 @@ def main(): { "results_dir": args.outdir, "subset": args.toy, - "single_dist_debug": args.debug_dist, "single_timestep_debug": args.debug_single_time, "cpu_only": args.cpu, "ngpu": args.ngpu,