Reorganize typing literals

This commit is contained in:
Kevin Wu
2022-09-09 11:46:38 -07:00
parent 0b095dbbda
commit e84bd412e9
3 changed files with 58 additions and 27 deletions

View File

@@ -1,5 +1,7 @@
"""
Training script
Training script.
Example usage: python ~/protdiff/bin/train.py ~/protdiff/config_jsons/full_run_canonical_angles_only_zero_centered_1000_timesteps_reduced_len.json
"""
import os, sys
@@ -45,6 +47,11 @@ torch.manual_seed(6489)
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
# Define some typing literals
ANGLES_DEFINITIONS = Literal[
"canonical", "canonical-full-angles", "canonical-minimal-angles"
]
@pl.utilities.rank_zero_only
def plot_timestep_distributions(
@@ -83,6 +90,11 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
"""
Plot the KL divergence over time
"""
# This works because the main body of this script should clean out the dir
# between runs
outname = plots_folder / "kl_divergence_timesteps.pdf"
if outname.is_file():
logging.info(f"KL divergence plot exists at {outname}; skipping...")
kl_at_timesteps = cm.kl_from_dset(train_dset) # Shape (n_timesteps, n_features)
n_timesteps, n_features = kl_at_timesteps.shape
fig, axes = plt.subplots(
@@ -98,14 +110,14 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
fig.suptitle(
f"KL(empirical || Gaussian) over timesteps={train_dset.timesteps}", y=1.05
)
fig.savefig(plots_folder / "kl_divergence_timesteps.pdf", bbox_inches="tight")
fig.savefig(outname, bbox_inches="tight")
def get_train_valid_test_sets(
angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical-full-angles",
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
max_seq_len: int = 512,
min_seq_len: int = 0,
seq_trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
timesteps: int = 250,
variance_schedule: SCHEDULES = "linear",
zero_center: bool = False,
@@ -269,16 +281,16 @@ def train(
# Controls output
results_dir: str = "./results",
# Controls data loading and noising process
angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical",
angles_definitions: ANGLES_DEFINITIONS = "canonical",
max_seq_len: int = 512,
min_seq_len: int = 0, # 0 means no filtering based on min sequence length
trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
zero_center: bool = False,
timesteps: int = 250,
variance_schedule: SCHEDULES = "linear", # cosine better on single angle toy test
variance_scale: float = 1.0,
# Related to model architecture
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
time_encoding: modelling.TIME_ENCODING = "gaussian_fourier",
num_hidden_layers: int = 12, # Default 12
hidden_size: int = 384, # Default 768
intermediate_size: int = 768, # Default 3072
@@ -287,19 +299,19 @@ def train(
"absolute", "relative_key", "relative_key_query"
] = "absolute", # relative_key = https://arxiv.org/pdf/1803.02155.pdf | relative_key_query = https://arxiv.org/pdf/2009.13658.pdf
dropout_p: float = 0.1, # Default 0.1, can disable for debugging
decoder: Literal["mlp", "linear"] = "mlp",
decoder: modelling.DECODER_HEAD = "mlp",
# Related to training strategy
gradient_clip: float = 1.0, # From BERT trainer
batch_size: int = 64,
lr: float = 5e-5, # Default lr for huggingface BERT trainer
loss: Literal["l1", "smooth_l1"] = "smooth_l1",
loss: modelling.LOSS_KEYS = "smooth_l1",
l2_norm: float = 0.0, # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
l1_norm: float = 0.0,
circle_reg: float = 0.0,
min_epochs: Optional[int] = None,
max_epochs: int = 10000,
early_stop_patience: int = 0, # Set to 0 to disable early stopping
lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None,
lr_scheduler: Optional[modelling.LR_SCHEDULE] = None,
use_swa: bool = False, # Stochastic weight averaging can improve training genearlization
# Misc. and debugging
multithread: bool = True,
@@ -539,9 +551,12 @@ def main():
if __name__ == "__main__":
curr_time = datetime.now().strftime("%y%m%d_%H%M%S")
logging.basicConfig(level=logging.INFO, handlers=[
logging.FileHandler(f"training_{curr_time}.log"),
logging.StreamHandler()
])
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler(f"training_{curr_time}.log"),
logging.StreamHandler(),
],
)
main()

View File

@@ -41,9 +41,7 @@ from angles_and_coords import (
from custom_metrics import kl_from_empirical, wrapped_mean
import utils
ANGLES_DEFINITIONS = Literal[
"canonical", "canonical-full-angles", "canonical-minimal-angles"
]
TRIM_STRATEGIES = Literal["leftalign", "randomcrop"]
class CathConsecutiveAnglesDataset(Dataset):
@@ -255,7 +253,7 @@ class CathCanonicalAnglesDataset(Dataset):
split: Optional[Literal["train", "test", "validation"]] = None,
pad: int = 512,
min_length: int = 40, # Set to 0 to disable
trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
trim_strategy: TRIM_STRATEGIES = "leftalign",
toy: int = 0,
zero_center: bool = False, # Center the features to have 0 mean
use_cache: bool = False, # Use/build cached computations of dihedrals and angles
@@ -743,7 +741,10 @@ class NoisedAnglesDataset(Dataset):
return noise
def __getitem__(
self, index: int, use_t_val: Optional[int] = None, ignore_zero_center: bool = False
self,
index: int,
use_t_val: Optional[int] = None,
ignore_zero_center: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Gets the i-th item in the dataset and adds noise
@@ -761,7 +762,9 @@ class NoisedAnglesDataset(Dataset):
assert (
item_index * self.timesteps + time_index == index
), f"Unexpected indices for {index} -- {item_index} {time_index}"
item = self.dset.__getitem__(item_index, ignore_zero_center=ignore_zero_center)
item = self.dset.__getitem__(
item_index, ignore_zero_center=ignore_zero_center
)
else:
item = self.dset.__getitem__(index, ignore_zero_center=ignore_zero_center)

View File

@@ -29,6 +29,11 @@ from transformers.optimization import get_linear_schedule_with_warmup
import losses
import utils
LR_SCHEDULE = Literal["OneCycleLR", "LinearWarmup"]
TIME_ENCODING = Literal["gaussian_fourier", "sinusoidal"]
LOSS_KEYS = Literal["l1", "smooth_l1"]
DECODER_HEAD = Literal["mlp", "linear"]
class GaussianFourierProjection(nn.Module):
"""
@@ -140,7 +145,9 @@ class BertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
def forward(
self, input_embeds: torch.Tensor, position_ids: torch.LongTensor,
self,
input_embeds: torch.Tensor,
position_ids: torch.LongTensor,
) -> torch.Tensor:
assert position_ids is not None, "`position_ids` must be defined"
embeddings = input_embeds
@@ -222,16 +229,16 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
config,
ft_is_angular: List[bool] = [False, True, True, True],
ft_names: Optional[List[str]] = None,
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
decoder: Literal["linear", "mlp"] = "mlp",
time_encoding: TIME_ENCODING = "gaussian_fourier",
decoder: DECODER_HEAD = "mlp",
lr: float = 5e-5,
loss: Union[Callable, Literal["l1", "smooth_l1"]] = "smooth_l1",
loss: Union[Callable, LOSS_KEYS] = "smooth_l1",
l2: float = 0.0,
l1: float = 0.0,
circle_reg: float = 0.0,
epochs: int = 1,
steps_per_epoch: int = 250, # Dummy value
lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None,
lr_scheduler: Optional[LR_SCHEDULE] = None,
write_preds_to_dir: Optional[str] = None,
) -> None:
"""
@@ -470,7 +477,11 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
if position_ids is None:
# [1, seq_length]
position_ids = (
torch.arange(seq_length,).expand(batch_size, -1).type_as(timestep)
torch.arange(
seq_length,
)
.expand(batch_size, -1)
.type_as(timestep)
)
# pl.utilities.rank_zero_debug(
@@ -661,7 +672,9 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
Return optimizer. Limited support for some optimizers
"""
optim = torch.optim.AdamW(
self.parameters(), lr=self.learning_rate, weight_decay=self.l2_lambda,
self.parameters(),
lr=self.learning_rate,
weight_decay=self.l2_lambda,
)
retval = {"optimizer": optim}
if self.lr_scheduler: