Mirror of https://github.com/AngxiaoYue/ReQFlow.git
Synced 2026-06-04 20:24:22 +08:00
75 lines · 2.4 KiB · Python
import numpy as np
|
|
import os
|
|
import re
|
|
from data import protein
|
|
from openfold.utils import rigid_utils
|
|
|
|
|
|
# Module-level alias for `openfold.utils.rigid_utils.Rigid`. It is not used by
# the functions in this section; presumably kept so other modules can import
# `Rigid` from here — verify before removing.
Rigid = rigid_utils.Rigid
|
|
|
|
|
|
def create_full_prot(
        atom37: np.ndarray,
        atom37_mask: np.ndarray,
        aatype=None,
        b_factors=None,
        *,
        residue_index=None,
        chain_index=None,
):
    """Wrap atom37 coordinates in a `protein.Protein` for PDB export.

    Args:
        atom37: Atom coordinates of shape (n_res, 37, 3).
        atom37_mask: Per-atom presence mask, passed through unchanged as
            `Protein.atom_mask` (e.g. the (n_res, 37) boolean mask computed
            in `write_prot_to_pdb`).
        aatype: Per-residue type indices passed as `Protein.aatype`.
            Defaults to all zeros of length n_res.
        b_factors: Per-atom B-factors. Defaults to zeros of shape (n_res, 37).
        residue_index: Keyword-only per-residue indices. Defaults to
            0..n_res-1.
        chain_index: Keyword-only per-residue integer chain indices.
            Defaults to all zeros, i.e. a single chain.

    Returns:
        A `protein.Protein` assembled from the arguments.

    Raises:
        ValueError: If `atom37` does not have shape (n_res, 37, 3).
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if atom37.ndim != 3 or tuple(atom37.shape[-2:]) != (37, 3):
        raise ValueError(
            f'Expected atom37 with shape (n_res, 37, 3), got {tuple(atom37.shape)}')
    n = atom37.shape[0]
    if residue_index is None:
        residue_index = np.arange(n)
    if chain_index is None:
        # Integer dtype: chain indices are discrete IDs, not continuous values.
        chain_index = np.zeros(n, dtype=int)
    if b_factors is None:
        b_factors = np.zeros([n, 37])
    if aatype is None:
        aatype = np.zeros(n, dtype=int)
    return protein.Protein(
        atom_positions=atom37,
        atom_mask=atom37_mask,
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        b_factors=b_factors,
    )
|
|
|
|
|
|
def _max_existing_pdb_index(stem):
    """Return the largest k for which `<stem>_<k>.pdb` exists, or 0 if none.

    Only files named exactly `<basename(stem)>_<digits>.pdb` in the directory
    of `stem` are considered. A `stem` without a directory part is resolved
    against the current working directory.

    Raises:
        FileNotFoundError: If the directory does not exist.
    """
    file_dir = os.path.dirname(stem) or '.'
    # re.escape: the base name may contain regex metacharacters such as '.'.
    pattern = re.compile(re.escape(os.path.basename(stem)) + r'_(\d+)\.pdb')
    indices = [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(file_dir))
        if match
    ]
    return max(indices, default=0)


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray=None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
):
    """Write one structure, or a trajectory of structures, to a PDB file.

    Args:
        prot_pos: Coordinates of shape (n_res, 37, 3) for a single structure,
            or (n_frames, n_res, 37, 3) for a trajectory. Each frame is
            written with its own model number, starting at 1.
        file_path: Target path. Unless `no_indexing` is set, `_<k>` is
            inserted before a trailing `.pdb`, or appended together with
            `.pdb` if there is none.
        aatype: Optional residue types forwarded to `create_full_prot`.
        overwrite: When indexing, skip the directory scan and always use
            index 1, overwriting any existing `<stem>_1.pdb`.
        no_indexing: Write exactly to `file_path`.
        b_factors: Optional per-atom B-factors forwarded to `create_full_prot`.

    Returns:
        The path that was written.

    Raises:
        ValueError: If `prot_pos` is neither 3- nor 4-dimensional. This is
            raised before any file is created or truncated.
    """
    # Validate first so an invalid input never leaves an empty file behind.
    if prot_pos.ndim == 3:
        frames = prot_pos[None]  # a single structure is a one-frame trajectory
    elif prot_pos.ndim == 4:
        frames = prot_pos
    else:
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    if no_indexing:
        save_path = file_path
    else:
        # Remove only a trailing '.pdb'. str.strip('.pdb') would strip any of
        # the characters '.', 'p', 'd', 'b' from both ends of the name, and
        # str.replace would also hit '.pdb' inside directory names.
        if file_path.endswith('.pdb'):
            stem = file_path[:-len('.pdb')]
        else:
            stem = file_path
        next_idx = 1 if overwrite else _max_existing_pdb_index(stem) + 1
        save_path = f'{stem}_{next_idx}.pdb'

    with open(save_path, 'w') as f:
        for t, pos37 in enumerate(frames):
            # An atom counts as present when the sum of its absolute
            # coordinates exceeds 1e-7 (all-zero coordinates mean "absent").
            atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
            prot = create_full_prot(
                pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
            f.write(protein.to_pdb(prot, model=t + 1, add_end=False))
        f.write('END')
    return save_path
|