Update README with demo code snippet

This commit is contained in:
Kevin Wu
2022-09-28 11:13:13 -07:00
parent 353ace0c30
commit 8c60f2a644

View File

@@ -26,6 +26,7 @@ We requires some data files not packaged on Git due to their large size. These a
```bash
# Download the CATH dataset
cd data # Ensure that you are in the data subdirectory within the codebase
chmod +x download_cath.sh
./download_cath.sh
```
@@ -50,12 +51,24 @@ results/
## Pre-trained models
We provide weihts for a model trained on the CATH dataset. These weights are located under the `models/cath_pretrained` directory and are stored via Git LFS. To programmatically load these weights, you can use code defined under `foldingdiff/modelling.py` as such:
We provide weihts for a model trained on the CATH dataset. These weights are located under the `models/cath_pretrained` directory and are stored via Git LFS. The following code snippet shows how to load this model, load data, and perform a forward pass:
```python
from torch.utils.data.dataloader import DataLoader
from foldingdiff import modelling
from foldingdiff import datasets as dsets
modelling.BertForDiffusion.from_dir("models/cath_pretrained")
# Load the model
m = modelling.BertForDiffusion.from_dir("models/cath_pretrained")
# Load dataset
clean_dset = dsets.CathCanonicalAnglesOnlyDataset(pad=128, trim_strategy='randomcrop')
noised_dset = dsets.NoisedAnglesDataset(clean_dset, timesteps=1000, beta_schedule='cosine')
dl = DataLoader(noised_dset, batch_size=32, shuffle=False)
x = iter(dl).next()
# Forward pass
predicted_noise = m(x['corrupted'], x['t'], x['attn_mask'])
```
Providing this path to premade script such as for sampling is detailed below.