mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Update README with demo code snippet
This commit is contained in:
17
README.md
17
README.md
@@ -26,6 +26,7 @@ We requires some data files not packaged on Git due to their large size. These a
|
||||
```bash
|
||||
# Download the CATH dataset
|
||||
cd data # Ensure that you are in the data subdirectory within the codebase
|
||||
chmod +x download_cath.sh
|
||||
./download_cath.sh
|
||||
```
|
||||
|
||||
@@ -50,12 +51,24 @@ results/
|
||||
|
||||
## Pre-trained models
|
||||
|
||||
We provide weihts for a model trained on the CATH dataset. These weights are located under the `models/cath_pretrained` directory and are stored via Git LFS. To programmatically load these weights, you can use code defined under `foldingdiff/modelling.py` as such:
|
||||
We provide weihts for a model trained on the CATH dataset. These weights are located under the `models/cath_pretrained` directory and are stored via Git LFS. The following code snippet shows how to load this model, load data, and perform a forward pass:
|
||||
|
||||
```python
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from foldingdiff import modelling
|
||||
from foldingdiff import datasets as dsets
|
||||
|
||||
modelling.BertForDiffusion.from_dir("models/cath_pretrained")
|
||||
# Load the model
|
||||
m = modelling.BertForDiffusion.from_dir("models/cath_pretrained")
|
||||
|
||||
# Load dataset
|
||||
clean_dset = dsets.CathCanonicalAnglesOnlyDataset(pad=128, trim_strategy='randomcrop')
|
||||
noised_dset = dsets.NoisedAnglesDataset(clean_dset, timesteps=1000, beta_schedule='cosine')
|
||||
dl = DataLoader(noised_dset, batch_size=32, shuffle=False)
|
||||
x = iter(dl).next()
|
||||
|
||||
# Forward pass
|
||||
predicted_noise = m(x['corrupted'], x['t'], x['attn_mask'])
|
||||
```
|
||||
|
||||
Providing this path to premade script such as for sampling is detailed below.
|
||||
|
||||
Reference in New Issue
Block a user