mirror of
https://github.com/tsa87/cgflow.git
synced 2026-06-04 12:14:22 +08:00
bugfix in generator
This commit is contained in:
13
README.md
13
README.md
@@ -101,9 +101,18 @@ In this setting, we directly using the final predicted pose from pose prediction
|
||||
|
||||
### 2. Zero-shot Pocket-conditional Generation
|
||||
|
||||
Please refer to the `experiments/scripts/exp3Z_sampling.py` for example code for zero-shot pocket-conditional generation. You can download the pretrained model weights from [here](https://drive.google.com/drive/folders/1gBz-xTw6gf5nwjcB4ZPX63Y4ebGoSWkU?usp=sharing).
|
||||
You can download the pretrained model weights from [here](https://drive.google.com/drive/folders/1gBz-xTw6gf5nwjcB4ZPX63Y4ebGoSWkU?usp=sharing).
|
||||
|
||||
You can swap the pose prediction model weights to either `crossdocked2020_till_end.ckpt` which is trained on CrossDock2020 or `plinder_till_end.ckpt` which is trained on Plinder.
|
||||
```bash
|
||||
python scripts/multi_pocket/sample.py \
|
||||
--protein_path data/examples/aldh1_protein.pdb \
|
||||
--ref_ligand_path data/examples/aldh1_ligand.mol2 \
|
||||
--env_dir "<ENV_DIR>" \
|
||||
--device cuda \
|
||||
--save_dir ./out/ \
|
||||
--flow_model ./weights/cgflow_crossdock.ckpt \
|
||||
--gfn_model ./weights/3dsynthflow_tacogfn.ckpt
|
||||
```
|
||||
|
||||
### 3. Fine-tuning the pocket-conditional model
|
||||
|
||||
|
||||
119
scripts/multi_pocket/sample.py
Normal file
119
scripts/multi_pocket/sample.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Pocket
|
||||
parser.add_argument("--protein_path", type=Path, required=True, help="Directory containing environment data.")
|
||||
parser.add_argument(
|
||||
"--ref_ligand_path", type=Path, help="Path to the reference ligand file to determine pocket center."
|
||||
)
|
||||
parser.add_argument("--center", type=float, nargs=3, help="Pocket center coordinates.")
|
||||
|
||||
# Environment
|
||||
parser.add_argument("--env_dir", type=Path, required=True, help="Directory for environment.")
|
||||
|
||||
# Generation
|
||||
parser.add_argument("--save_dir", type=Path, required=True, help="Directory to save sampled results.")
|
||||
parser.add_argument("--num_samples", type=int, help="Number of samples per pocket.", default=100)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=int,
|
||||
nargs=2,
|
||||
help="Temperature (Exploration-Exploitation Trade-off). example: `--temp 1 64` (unif)",
|
||||
default=[16, 48],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subsampling_ratio",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Subsampling ratio for action space; Memory-variance trade-off.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="Random seed for reproducibility.")
|
||||
parser.add_argument("--batch_size", type=int, default=50, help="Batch size for training.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation.")
|
||||
|
||||
# model
|
||||
parser.add_argument(
|
||||
"--flow_model",
|
||||
type=Path,
|
||||
help="Path to the flow model checkpoint.",
|
||||
default="./weights/cgflow_crossdock.ckpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gfn_model",
|
||||
type=Path,
|
||||
help="Path to the GFN model checkpoint.",
|
||||
default="./weights/3dsynthflow_tacogfn.ckpt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.center is None and args.ref_ligand_path is None:
|
||||
parser.error("Either --center or --ref_ligand_path must be provided to determine the pocket center.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from rdkit import Chem
|
||||
|
||||
from synthflow.config import Config, init_empty
|
||||
from synthflow.pocket_conditional.sampler import PocketConditionalSampler
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# Set seed
|
||||
set_seed(args.seed)
|
||||
|
||||
# Create sampler
|
||||
config = init_empty(Config())
|
||||
config.env_dir = args.env_dir
|
||||
config.cgflow.ckpt_path = args.flow_model
|
||||
config.algo.action_subsampling.sampling_ratio = args.subsampling_ratio
|
||||
config.algo.max_nodes = 40
|
||||
config.algo.num_from_policy = args.batch_size # batch size
|
||||
sampler = PocketConditionalSampler(config, args.gfn_model, args.device)
|
||||
sampler.update_temperature("uniform", args.temperature)
|
||||
|
||||
save_dir = Path(args.save_dir)
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
smiles_path = save_dir / "smiles.csv"
|
||||
pose_path = save_dir / "pose.sdf"
|
||||
|
||||
protein_path = args.protein_path
|
||||
if args.ref_ligand_path is not None:
|
||||
sampler.set_pocket(protein_path, ref_ligand_path=args.ref_ligand_path)
|
||||
else:
|
||||
assert args.center is not None
|
||||
sampler.set_pocket(protein_path, center=args.center)
|
||||
|
||||
res = sampler.sample(100)
|
||||
|
||||
with open(smiles_path, "w") as w:
|
||||
w.write(",SMILES\n")
|
||||
for idx, sample in enumerate(res):
|
||||
smiles = sample["smiles"]
|
||||
w.write(f"{idx},{smiles}\n")
|
||||
|
||||
with Chem.SDWriter(str(pose_path)) as w:
|
||||
for i, sample in enumerate(res):
|
||||
mol = sample["mol"]
|
||||
mol.SetIntProp("sample_idx", i)
|
||||
w.write(mol)
|
||||
@@ -2,6 +2,25 @@ from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class TBVariant(IntEnum):
|
||||
"""The Trajectory Balance variant to use."""
|
||||
|
||||
TB = 0
|
||||
SubTB = 1
|
||||
DB = 2
|
||||
|
||||
|
||||
class NLoss(IntEnum):
|
||||
"""See algo.trajectory_balance.TrajectoryBalance for details."""
|
||||
|
||||
none = 0
|
||||
Transition = 1
|
||||
SubTB1 = 2
|
||||
TermTB1 = 3
|
||||
StartTB1 = 4
|
||||
TB = 5
|
||||
|
||||
|
||||
class Backward(IntEnum):
|
||||
"""
|
||||
See algo.trajectory_balance.TrajectoryBalance for details.
|
||||
|
||||
@@ -58,7 +58,7 @@ class TrajectoryBalance(GFNAlgorithm):
|
||||
self.env = env
|
||||
self.global_cfg = cfg
|
||||
self.cfg = cfg.algo.tb
|
||||
self.max_nodes = cfg.algo.max_nodes
|
||||
self.max_nodes = 9
|
||||
self.max_len = cfg.algo.max_len
|
||||
self.length_normalize_losses = cfg.algo.tb.do_length_normalize
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class RxnFlowSampler:
|
||||
self.ctx = SynthesisEnvContext(self.env, num_cond_dim=self.task.num_cond_dim)
|
||||
|
||||
def setup_model(self):
|
||||
self.model = RxnFlow(self.ctx, self.cfg, do_bck=False, num_graph_out=self.cfg.algo.tb.do_predict_n + 1)
|
||||
self.model = RxnFlow(self.ctx, self.cfg)
|
||||
|
||||
def setup_algo(self):
|
||||
assert self.cfg.algo.method == "TB"
|
||||
|
||||
Reference in New Issue
Block a user