mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 18:04:23 +08:00
52 lines
2.1 KiB
Python
52 lines
2.1 KiB
Python
from rdkit.Chem.rdmolfiles import MolToPDBBlock, MolToPDBFile
|
|
import rdkit.Chem
|
|
from rdkit import Geometry
|
|
from collections import defaultdict
|
|
import copy
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class PDBFile:
    """Accumulate coordinate snapshots of one molecule into a multi-model PDB.

    Each snapshot is stored as a list of PDB text lines under
    ``self.parts[part][order]`` together with a repeat count, and ``write``
    emits every stored snapshot as a ``MODEL`` ... ``ENDMDL`` record.
    """

    def __init__(self, mol):
        """Keep a private copy of *mol* reduced to a single conformer.

        Args:
            mol: RDKit molecule whose topology is reused for every coordinate
                snapshot passed to ``add``. It is deep-copied, so the caller's
                molecule is never modified.
        """
        # part -> {order -> {'block': [pdb lines], 'repeat': int}}
        self.parts = defaultdict(dict)
        self.mol = copy.deepcopy(mol)
        # Only one conformer is needed as a scratch buffer for new coordinates.
        # Collect the real conformer IDs first, because RemoveConformer mutates
        # the conformer list. Using the IDs themselves (rather than range(n))
        # also works when the IDs are not contiguous from 0.
        conf_ids = [conf.GetId() for conf in self.mol.GetConformers()]
        for conf_id in conf_ids[1:]:
            self.mol.RemoveConformer(conf_id)

    @staticmethod
    def _pdb_lines(mol):
        """Return the PDB text of *mol* as a list of lines without the final END record."""
        # MolToPDBBlock output ends with an 'END' line plus a trailing newline,
        # so splitting on '\n' yields [..., 'END', '']; drop those two items.
        return MolToPDBBlock(mol).split('\n')[:-2]

    def add(self, coords, order, part=0, repeat=1):
        """Store one snapshot as PDB lines under ``self.parts[part][order]``.

        Args:
            coords: either an RDKit ``Mol``/``RWMol`` (rendered as-is, with its
                own topology and conformer), or per-atom coordinates given as a
                NumPy array, a torch tensor, or a nested sequence. Row ``i``
                becomes the position of atom ``i`` of the stored molecule, and
                columns 0..2 are used as x, y, z. Presumably the shape is
                (num_atoms, 3) in the constructor molecule's atom order; verify
                this against callers.
            order: sort key within *part*. ``write`` emits non-negative keys
                first (ascending), then negative keys (ascending).
            part: group key; ``write`` emits parts in ascending key order.
            repeat: how many times ``write`` emits this snapshot.
        """
        if isinstance(coords, (rdkit.Chem.Mol, rdkit.Chem.RWMol)):
            self.parts[part][order] = {'block': self._pdb_lines(coords), 'repeat': repeat}
            return

        if isinstance(coords, torch.Tensor):
            # detach()/cpu() let tensors that require grad or live on a GPU be
            # converted; .numpy() raises for both otherwise.
            coords = coords.detach().cpu().double().numpy()
        coords = np.asarray(coords, dtype=np.float64)

        # GetConformer() with no id returns the first (here: the only) conformer.
        conformer = self.mol.GetConformer()
        for i, row in enumerate(coords):
            conformer.SetAtomPosition(i, Geometry.Point3D(float(row[0]), float(row[1]), float(row[2])))

        self.parts[part][order] = {'block': self._pdb_lines(self.mol), 'repeat': repeat}

    def write(self, path=None, limit_parts=None):
        """Render all stored snapshots as a multi-model PDB.

        Parts are emitted in ascending key order. Within a part, non-negative
        ``order`` keys come first (ascending), then negative keys (ascending),
        so ``-1`` is always the last snapshot of its part. Each snapshot is
        emitted ``repeat`` times, each copy wrapped in ``MODEL``/``ENDMDL``.
        Only the very first emitted model keeps its ``CONECT`` lines.

        Args:
            path: if falsy, the PDB text is returned. Otherwise the text is
                written to this file path and nothing is returned.
            limit_parts: if truthy, only parts whose key is ``< limit_parts``
                are emitted. ``None`` and ``0`` both mean "no limit".

        Returns:
            The PDB text when *path* is falsy, otherwise ``None``.
        """
        chunks = []
        keep_conect = True
        for part_key in sorted(self.parts):
            # Keys are sorted, so the first key over the limit ends the loop.
            if limit_parts and part_key >= limit_parts:
                break
            snapshots = self.parts[part_key]
            orders = (sorted(k for k in snapshots if k >= 0)
                      + sorted(k for k in snapshots if k < 0))
            for order in orders:
                lines = snapshots[order]['block']
                for _ in range(snapshots[order]['repeat']):
                    if not keep_conect:
                        lines = [line for line in lines if 'CONECT' not in line]
                    keep_conect = False
                    chunks.append('MODEL\n' + '\n'.join(lines) + '\nENDMDL\n')

        pdb_text = ''.join(chunks)
        if not path:
            return pdb_text
        with open(path, 'w') as f:
            f.write(pdb_text)