From e84bd412e99aa2b6735510cfbd9c4c7e3840da7b Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Fri, 9 Sep 2022 11:46:38 -0700 Subject: [PATCH] Reorganize typing literals --- bin/train.py | 43 +++++++++++++++++++++++++++++-------------- protdiff/datasets.py | 15 +++++++++------ protdiff/modelling.py | 27 ++++++++++++++++++++------- 3 files changed, 58 insertions(+), 27 deletions(-) diff --git a/bin/train.py b/bin/train.py index 3b90e07..b945ffe 100644 --- a/bin/train.py +++ b/bin/train.py @@ -1,5 +1,7 @@ """ -Training script +Training script. + +Example usage: python ~/protdiff/bin/train.py ~/protdiff/config_jsons/full_run_canonical_angles_only_zero_centered_1000_timesteps_reduced_len.json """ import os, sys @@ -45,6 +47,11 @@ torch.manual_seed(6489) # torch.use_deterministic_algorithms(True) torch.backends.cudnn.benchmark = False +# Define some typing literals +ANGLES_DEFINITIONS = Literal[ + "canonical", "canonical-full-angles", "canonical-minimal-angles" +] + @pl.utilities.rank_zero_only def plot_timestep_distributions( @@ -83,6 +90,11 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None: """ Plot the KL divergence over time """ + # This works because the main body of this script should clean out the dir + # between runs + outname = plots_folder / "kl_divergence_timesteps.pdf" + if outname.is_file(): + logging.info(f"KL divergence plot exists at {outname}; skipping...") kl_at_timesteps = cm.kl_from_dset(train_dset) # Shape (n_timesteps, n_features) n_timesteps, n_features = kl_at_timesteps.shape fig, axes = plt.subplots( @@ -98,14 +110,14 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None: fig.suptitle( f"KL(empirical || Gaussian) over timesteps={train_dset.timesteps}", y=1.05 ) - fig.savefig(plots_folder / "kl_divergence_timesteps.pdf", bbox_inches="tight") + fig.savefig(outname, bbox_inches="tight") def get_train_valid_test_sets( - angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical-full-angles", + angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles", max_seq_len: int = 512, min_seq_len: int = 0, - seq_trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign", + seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign", timesteps: int = 250, variance_schedule: SCHEDULES = "linear", zero_center: bool = False, @@ -269,16 +281,16 @@ def train( # Controls output results_dir: str = "./results", # Controls data loading and noising process - angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical", + angles_definitions: ANGLES_DEFINITIONS = "canonical", max_seq_len: int = 512, min_seq_len: int = 0, # 0 means no filtering based on min sequence length - trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign", + trim_strategy: datasets.TRIM_STRATEGIES = "leftalign", zero_center: bool = False, timesteps: int = 250, variance_schedule: SCHEDULES = "linear", # cosine better on single angle toy test variance_scale: float = 1.0, # Related to model architecture - time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier", + time_encoding: modelling.TIME_ENCODING = "gaussian_fourier", num_hidden_layers: int = 12, # Default 12 hidden_size: int = 384, # Default 768 intermediate_size: int = 768, # Default 3072 @@ -287,19 +299,19 @@ def train( "absolute", "relative_key", "relative_key_query" ] = "absolute", # relative_key = https://arxiv.org/pdf/1803.02155.pdf | relative_key_query = https://arxiv.org/pdf/2009.13658.pdf dropout_p: float = 0.1, # Default 0.1, can disable for debugging - decoder: Literal["mlp", "linear"] = "mlp", + decoder: modelling.DECODER_HEAD = "mlp", # Related to training strategy gradient_clip: float = 1.0, # From BERT trainer batch_size: int = 64, lr: float = 5e-5, # Default lr for huggingface BERT trainer - loss: Literal["l1", "smooth_l1"] = "smooth_l1", + loss: modelling.LOSS_KEYS = "smooth_l1", l2_norm: float = 0.0, # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0 l1_norm: float = 0.0, circle_reg: float = 0.0, min_epochs: Optional[int] = None, max_epochs: int = 10000, early_stop_patience: int = 0, # Set to 0 to disable early stopping - lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None, + lr_scheduler: Optional[modelling.LR_SCHEDULE] = None, use_swa: bool = False, # Stochastic weight averaging can improve training genearlization # Misc. and debugging multithread: bool = True, @@ -539,9 +551,12 @@ def main(): if __name__ == "__main__": curr_time = datetime.now().strftime("%y%m%d_%H%M%S") - logging.basicConfig(level=logging.INFO, handlers=[ - logging.FileHandler(f"training_{curr_time}.log"), - logging.StreamHandler() - ]) + logging.basicConfig( + level=logging.INFO, + handlers=[ + logging.FileHandler(f"training_{curr_time}.log"), + logging.StreamHandler(), + ], + ) main() diff --git a/protdiff/datasets.py b/protdiff/datasets.py index 01bf28a..9cfa632 100644 --- a/protdiff/datasets.py +++ b/protdiff/datasets.py @@ -41,9 +41,7 @@ from angles_and_coords import ( from custom_metrics import kl_from_empirical, wrapped_mean import utils -ANGLES_DEFINITIONS = Literal[ - "canonical", "canonical-full-angles", "canonical-minimal-angles" -] +TRIM_STRATEGIES = Literal["leftalign", "randomcrop"] class CathConsecutiveAnglesDataset(Dataset): @@ -255,7 +253,7 @@ class CathCanonicalAnglesDataset(Dataset): split: Optional[Literal["train", "test", "validation"]] = None, pad: int = 512, min_length: int = 40, # Set to 0 to disable - trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign", + trim_strategy: TRIM_STRATEGIES = "leftalign", toy: int = 0, zero_center: bool = False, # Center the features to have 0 mean use_cache: bool = False, # Use/build cached computations of dihedrals and angles @@ -743,7 +741,10 @@ class NoisedAnglesDataset(Dataset): return noise def __getitem__( - self, index: int, use_t_val: Optional[int] = None, ignore_zero_center: bool = False + self, + index: int, + use_t_val: Optional[int] = None, + ignore_zero_center: bool = False, ) -> Dict[str, torch.Tensor]: """ Gets the i-th item in the dataset and adds noise @@ -761,7 +762,9 @@ class NoisedAnglesDataset(Dataset): assert ( item_index * self.timesteps + time_index == index ), f"Unexpected indices for {index} -- {item_index} {time_index}" - item = self.dset.__getitem__(item_index, ignore_zero_center=ignore_zero_center) + item = self.dset.__getitem__( + item_index, ignore_zero_center=ignore_zero_center + ) else: item = self.dset.__getitem__(index, ignore_zero_center=ignore_zero_center) diff --git a/protdiff/modelling.py b/protdiff/modelling.py index 0eaa1c2..b4d27a9 100644 --- a/protdiff/modelling.py +++ b/protdiff/modelling.py @@ -29,6 +29,11 @@ from transformers.optimization import get_linear_schedule_with_warmup import losses import utils +LR_SCHEDULE = Literal["OneCycleLR", "LinearWarmup"] +TIME_ENCODING = Literal["gaussian_fourier", "sinusoidal"] +LOSS_KEYS = Literal["l1", "smooth_l1"] +DECODER_HEAD = Literal["mlp", "linear"] + class GaussianFourierProjection(nn.Module): """ @@ -140,7 +145,9 @@ class BertEmbeddings(nn.Module): # position_ids (1, len position emb) is contiguous in memory and exported when serialized def forward( - self, input_embeds: torch.Tensor, position_ids: torch.LongTensor, + self, + input_embeds: torch.Tensor, + position_ids: torch.LongTensor, ) -> torch.Tensor: assert position_ids is not None, "`position_ids` must be defined" embeddings = input_embeds @@ -222,16 +229,16 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule): config, ft_is_angular: List[bool] = [False, True, True, True], ft_names: Optional[List[str]] = None, - time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier", - decoder: Literal["linear", "mlp"] = "mlp", + time_encoding: TIME_ENCODING = "gaussian_fourier", + decoder: DECODER_HEAD = "mlp", lr: float = 5e-5, - loss: Union[Callable, Literal["l1", "smooth_l1"]] = "smooth_l1", + loss: Union[Callable, LOSS_KEYS] = "smooth_l1", l2: float = 0.0, l1: float = 0.0, circle_reg: float = 0.0, epochs: int = 1, steps_per_epoch: int = 250, # Dummy value - lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None, + lr_scheduler: Optional[LR_SCHEDULE] = None, write_preds_to_dir: Optional[str] = None, ) -> None: """ @@ -470,7 +477,11 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule): if position_ids is None: # [1, seq_length] position_ids = ( - torch.arange(seq_length,).expand(batch_size, -1).type_as(timestep) + torch.arange( + seq_length, + ) + .expand(batch_size, -1) + .type_as(timestep) ) # pl.utilities.rank_zero_debug( @@ -661,7 +672,9 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule): Return optimizer. Limited support for some optimizers """ optim = torch.optim.AdamW( - self.parameters(), lr=self.learning_rate, weight_decay=self.l2_lambda, + self.parameters(), + lr=self.learning_rate, + weight_decay=self.l2_lambda, ) retval = {"optimizer": optim} if self.lr_scheduler: