mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Speed up sampling by truncating input padding regions.
This commit is contained in:
@@ -281,6 +281,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--testcomparison", action="store_true", help="Run comparison against test set"
|
||||
)
|
||||
parser.add_argument("--nopsea", action="store_true", help="Skip PSEA calculations")
|
||||
parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
|
||||
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
|
||||
return parser
|
||||
@@ -453,18 +454,19 @@ def main() -> None:
|
||||
)
|
||||
|
||||
# Generate plots of secondary structure co-occurrence
|
||||
make_ss_cooccurrence_plot(
|
||||
pdb_files,
|
||||
str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
|
||||
threads=multiprocessing.cpu_count(),
|
||||
)
|
||||
if args.testcomparison:
|
||||
if not args.nopsea:
|
||||
make_ss_cooccurrence_plot(
|
||||
test_dset.filenames,
|
||||
str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
|
||||
max_seq_len=test_dset.dset.pad,
|
||||
pdb_files,
|
||||
str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
|
||||
threads=multiprocessing.cpu_count(),
|
||||
)
|
||||
if args.testcomparison:
|
||||
make_ss_cooccurrence_plot(
|
||||
test_dset.filenames,
|
||||
str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
|
||||
max_seq_len=test_dset.dset.pad,
|
||||
threads=multiprocessing.cpu_count(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -54,8 +54,8 @@ def p_sample(
|
||||
|
||||
# Create the attention mask
|
||||
attn_mask = torch.zeros(x.shape[:2], device=x.device)
|
||||
for i, l in enumerate(seq_lens):
|
||||
attn_mask[i, :l] = 1.0
|
||||
for i, length in enumerate(seq_lens):
|
||||
attn_mask[i, :length] = 1.0
|
||||
|
||||
# Equation 11 in the paper
|
||||
# Use our model (noise predictor) to predict the mean
|
||||
@@ -140,6 +140,7 @@ def sample(
|
||||
batch_size: int = 512,
|
||||
feature_key: str = "angles",
|
||||
disable_pbar: bool = False,
|
||||
trim_to_length: bool = True, # Trim padding regions to reduce memory
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Sample from the given model. Use the train_dset to generate noise to sample
|
||||
@@ -157,6 +158,10 @@ def sample(
|
||||
# Process each batch
|
||||
if sweep_lengths is not None:
|
||||
sweep_min, sweep_max = sweep_lengths
|
||||
if not sweep_min < sweep_max:
|
||||
raise ValueError(
|
||||
f"Minimum length {sweep_min} must be less than maximum {sweep_max}"
|
||||
)
|
||||
logging.info(
|
||||
f"Sweeping from {sweep_min}-{sweep_max} with {n} examples at each length"
|
||||
)
|
||||
@@ -177,6 +182,11 @@ def sample(
|
||||
noise = train_dset.sample_noise(
|
||||
torch.zeros((batch, train_dset.pad, model.n_inputs), dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Trim things that are beyond the length of what we are generating
|
||||
if trim_to_length:
|
||||
noise = noise[:, : max(this_lengths), :]
|
||||
|
||||
# Produces (timesteps, batch_size, seq_len, n_ft)
|
||||
sampled = p_sample_loop(
|
||||
model=model,
|
||||
@@ -255,7 +265,7 @@ def sample_simple(
|
||||
|
||||
|
||||
def _score_angles(
|
||||
reconst_angles:pd.DataFrame, truth_angles:pd.DataFrame, truth_coords_pdb: str
|
||||
reconst_angles: pd.DataFrame, truth_angles: pd.DataFrame, truth_coords_pdb: str
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Helper function to scores sets of angles
|
||||
@@ -348,6 +358,7 @@ def get_reconstruction_error(
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 55))
|
||||
s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 51))
|
||||
for i, x in enumerate(s):
|
||||
print(x.shape)
|
||||
print(x)
|
||||
|
||||
Reference in New Issue
Block a user