mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Reorganize typing literals
This commit is contained in:
43
bin/train.py
43
bin/train.py
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Training script
|
||||
Training script.
|
||||
|
||||
Example usage: python ~/protdiff/bin/train.py ~/protdiff/config_jsons/full_run_canonical_angles_only_zero_centered_1000_timesteps_reduced_len.json
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
@@ -45,6 +47,11 @@ torch.manual_seed(6489)
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Define some typing literals
|
||||
ANGLES_DEFINITIONS = Literal[
|
||||
"canonical", "canonical-full-angles", "canonical-minimal-angles"
|
||||
]
|
||||
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def plot_timestep_distributions(
|
||||
@@ -83,6 +90,11 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
|
||||
"""
|
||||
Plot the KL divergence over time
|
||||
"""
|
||||
# This works because the main body of this script should clean out the dir
|
||||
# between runs
|
||||
outname = plots_folder / "kl_divergence_timesteps.pdf"
|
||||
if outname.is_file():
|
||||
logging.info(f"KL divergence plot exists at {outname}; skipping...")
|
||||
kl_at_timesteps = cm.kl_from_dset(train_dset) # Shape (n_timesteps, n_features)
|
||||
n_timesteps, n_features = kl_at_timesteps.shape
|
||||
fig, axes = plt.subplots(
|
||||
@@ -98,14 +110,14 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
|
||||
fig.suptitle(
|
||||
f"KL(empirical || Gaussian) over timesteps={train_dset.timesteps}", y=1.05
|
||||
)
|
||||
fig.savefig(plots_folder / "kl_divergence_timesteps.pdf", bbox_inches="tight")
|
||||
fig.savefig(outname, bbox_inches="tight")
|
||||
|
||||
|
||||
def get_train_valid_test_sets(
|
||||
angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical-full-angles",
|
||||
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0,
|
||||
seq_trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
|
||||
seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
|
||||
timesteps: int = 250,
|
||||
variance_schedule: SCHEDULES = "linear",
|
||||
zero_center: bool = False,
|
||||
@@ -269,16 +281,16 @@ def train(
|
||||
# Controls output
|
||||
results_dir: str = "./results",
|
||||
# Controls data loading and noising process
|
||||
angles_definitions: datasets.ANGLES_DEFINITIONS = "canonical",
|
||||
angles_definitions: ANGLES_DEFINITIONS = "canonical",
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0, # 0 means no filtering based on min sequence length
|
||||
trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
|
||||
trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
|
||||
zero_center: bool = False,
|
||||
timesteps: int = 250,
|
||||
variance_schedule: SCHEDULES = "linear", # cosine better on single angle toy test
|
||||
variance_scale: float = 1.0,
|
||||
# Related to model architecture
|
||||
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
|
||||
time_encoding: modelling.TIME_ENCODING = "gaussian_fourier",
|
||||
num_hidden_layers: int = 12, # Default 12
|
||||
hidden_size: int = 384, # Default 768
|
||||
intermediate_size: int = 768, # Default 3072
|
||||
@@ -287,19 +299,19 @@ def train(
|
||||
"absolute", "relative_key", "relative_key_query"
|
||||
] = "absolute", # relative_key = https://arxiv.org/pdf/1803.02155.pdf | relative_key_query = https://arxiv.org/pdf/2009.13658.pdf
|
||||
dropout_p: float = 0.1, # Default 0.1, can disable for debugging
|
||||
decoder: Literal["mlp", "linear"] = "mlp",
|
||||
decoder: modelling.DECODER_HEAD = "mlp",
|
||||
# Related to training strategy
|
||||
gradient_clip: float = 1.0, # From BERT trainer
|
||||
batch_size: int = 64,
|
||||
lr: float = 5e-5, # Default lr for huggingface BERT trainer
|
||||
loss: Literal["l1", "smooth_l1"] = "smooth_l1",
|
||||
loss: modelling.LOSS_KEYS = "smooth_l1",
|
||||
l2_norm: float = 0.0, # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
|
||||
l1_norm: float = 0.0,
|
||||
circle_reg: float = 0.0,
|
||||
min_epochs: Optional[int] = None,
|
||||
max_epochs: int = 10000,
|
||||
early_stop_patience: int = 0, # Set to 0 to disable early stopping
|
||||
lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None,
|
||||
lr_scheduler: Optional[modelling.LR_SCHEDULE] = None,
|
||||
use_swa: bool = False, # Stochastic weight averaging can improve training generalization
|
||||
# Misc. and debugging
|
||||
multithread: bool = True,
|
||||
@@ -539,9 +551,12 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
curr_time = datetime.now().strftime("%y%m%d_%H%M%S")
|
||||
logging.basicConfig(level=logging.INFO, handlers=[
|
||||
logging.FileHandler(f"training_{curr_time}.log"),
|
||||
logging.StreamHandler()
|
||||
])
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
logging.FileHandler(f"training_{curr_time}.log"),
|
||||
logging.StreamHandler(),
|
||||
],
|
||||
)
|
||||
|
||||
main()
|
||||
|
||||
@@ -41,9 +41,7 @@ from angles_and_coords import (
|
||||
from custom_metrics import kl_from_empirical, wrapped_mean
|
||||
import utils
|
||||
|
||||
ANGLES_DEFINITIONS = Literal[
|
||||
"canonical", "canonical-full-angles", "canonical-minimal-angles"
|
||||
]
|
||||
TRIM_STRATEGIES = Literal["leftalign", "randomcrop"]
|
||||
|
||||
|
||||
class CathConsecutiveAnglesDataset(Dataset):
|
||||
@@ -255,7 +253,7 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
split: Optional[Literal["train", "test", "validation"]] = None,
|
||||
pad: int = 512,
|
||||
min_length: int = 40, # Set to 0 to disable
|
||||
trim_strategy: Literal["leftalign", "randomcrop"] = "leftalign",
|
||||
trim_strategy: TRIM_STRATEGIES = "leftalign",
|
||||
toy: int = 0,
|
||||
zero_center: bool = False, # Center the features to have 0 mean
|
||||
use_cache: bool = False, # Use/build cached computations of dihedrals and angles
|
||||
@@ -743,7 +741,10 @@ class NoisedAnglesDataset(Dataset):
|
||||
return noise
|
||||
|
||||
def __getitem__(
|
||||
self, index: int, use_t_val: Optional[int] = None, ignore_zero_center: bool = False
|
||||
self,
|
||||
index: int,
|
||||
use_t_val: Optional[int] = None,
|
||||
ignore_zero_center: bool = False,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Gets the i-th item in the dataset and adds noise
|
||||
@@ -761,7 +762,9 @@ class NoisedAnglesDataset(Dataset):
|
||||
assert (
|
||||
item_index * self.timesteps + time_index == index
|
||||
), f"Unexpected indices for {index} -- {item_index} {time_index}"
|
||||
item = self.dset.__getitem__(item_index, ignore_zero_center=ignore_zero_center)
|
||||
item = self.dset.__getitem__(
|
||||
item_index, ignore_zero_center=ignore_zero_center
|
||||
)
|
||||
else:
|
||||
item = self.dset.__getitem__(index, ignore_zero_center=ignore_zero_center)
|
||||
|
||||
|
||||
@@ -29,6 +29,11 @@ from transformers.optimization import get_linear_schedule_with_warmup
|
||||
import losses
|
||||
import utils
|
||||
|
||||
LR_SCHEDULE = Literal["OneCycleLR", "LinearWarmup"]
|
||||
TIME_ENCODING = Literal["gaussian_fourier", "sinusoidal"]
|
||||
LOSS_KEYS = Literal["l1", "smooth_l1"]
|
||||
DECODER_HEAD = Literal["mlp", "linear"]
|
||||
|
||||
|
||||
class GaussianFourierProjection(nn.Module):
|
||||
"""
|
||||
@@ -140,7 +145,9 @@ class BertEmbeddings(nn.Module):
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
|
||||
def forward(
|
||||
self, input_embeds: torch.Tensor, position_ids: torch.LongTensor,
|
||||
self,
|
||||
input_embeds: torch.Tensor,
|
||||
position_ids: torch.LongTensor,
|
||||
) -> torch.Tensor:
|
||||
assert position_ids is not None, "`position_ids` must be defined"
|
||||
embeddings = input_embeds
|
||||
@@ -222,16 +229,16 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
|
||||
config,
|
||||
ft_is_angular: List[bool] = [False, True, True, True],
|
||||
ft_names: Optional[List[str]] = None,
|
||||
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
|
||||
decoder: Literal["linear", "mlp"] = "mlp",
|
||||
time_encoding: TIME_ENCODING = "gaussian_fourier",
|
||||
decoder: DECODER_HEAD = "mlp",
|
||||
lr: float = 5e-5,
|
||||
loss: Union[Callable, Literal["l1", "smooth_l1"]] = "smooth_l1",
|
||||
loss: Union[Callable, LOSS_KEYS] = "smooth_l1",
|
||||
l2: float = 0.0,
|
||||
l1: float = 0.0,
|
||||
circle_reg: float = 0.0,
|
||||
epochs: int = 1,
|
||||
steps_per_epoch: int = 250, # Dummy value
|
||||
lr_scheduler: Optional[Literal["OneCycleLR", "LinearWarmup"]] = None,
|
||||
lr_scheduler: Optional[LR_SCHEDULE] = None,
|
||||
write_preds_to_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -470,7 +477,11 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
|
||||
if position_ids is None:
|
||||
# [1, seq_length]
|
||||
position_ids = (
|
||||
torch.arange(seq_length,).expand(batch_size, -1).type_as(timestep)
|
||||
torch.arange(
|
||||
seq_length,
|
||||
)
|
||||
.expand(batch_size, -1)
|
||||
.type_as(timestep)
|
||||
)
|
||||
|
||||
# pl.utilities.rank_zero_debug(
|
||||
@@ -661,7 +672,9 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
|
||||
Return optimizer. Limited support for some optimizers
|
||||
"""
|
||||
optim = torch.optim.AdamW(
|
||||
self.parameters(), lr=self.learning_rate, weight_decay=self.l2_lambda,
|
||||
self.parameters(),
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.l2_lambda,
|
||||
)
|
||||
retval = {"optimizer": optim}
|
||||
if self.lr_scheduler:
|
||||
|
||||
Reference in New Issue
Block a user