From 9ca513f5e1fafbe6583b43c036b74a0b8f60ab86 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Thu, 9 Mar 2023 16:02:30 -0800 Subject: [PATCH] Threading updates --- scripts/gromacs/gromacs.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/scripts/gromacs/gromacs.py b/scripts/gromacs/gromacs.py index ef2d19b..895b112 100644 --- a/scripts/gromacs/gromacs.py +++ b/scripts/gromacs/gromacs.py @@ -18,7 +18,6 @@ import logging import shlex import subprocess import shutil -import multiprocessing as mp GRO_FILE_DIR = os.path.join(os.path.dirname(__file__), "mdp") @@ -28,6 +27,7 @@ def run_gromacs( outdir: str = os.getcwd(), gmx: str = "gmx", gro_file_dir: str = GRO_FILE_DIR, + n_threads: int = 8, ) -> float: """ Run GROMACS on a PDB file @@ -86,9 +86,7 @@ def run_gromacs( # come to "room temperature" grompp_cmd = f"{gmx} grompp -f {gro_file_dir}nvt.mdp -c em.gro -r em.gro -p topol.top -o nvt.tpr" subprocess.call(shlex.split(grompp_cmd)) - nvt_cmd = ( - f"{gmx} mdrun -ntmpi 1 -ntomp {mp.cpu_count() - 1} -nb gpu -pin on -deffnm nvt" - ) + nvt_cmd = f"{gmx} mdrun -ntmpi 1 -ntomp {n_threads - 1} -nb gpu -pin on -deffnm nvt" subprocess.call(shlex.split(nvt_cmd)) # NPT @@ -96,16 +94,14 @@ def run_gromacs( f"{gmx} grompp -f {gro_file_dir}npt.mdp -c nvt.gro -o npt.tpr -p topol.top" ) subprocess.call(shlex.split(grompp_cmd)) - npt_cmd = ( - f"{gmx} mdrun -ntmpi 1 -ntomp {mp.cpu_count() - 1} -nb gpu -pin on -deffnm npt" - ) + npt_cmd = f"{gmx} mdrun -ntmpi 1 -ntomp {n_threads - 1} -nb gpu -pin on -deffnm npt" subprocess.call(shlex.split(npt_cmd)) # Production run grompp_cmd = f"{gmx} grompp -f {gro_file_dir}md.mdp -c npt.gro -t npt.cpt -p topol.top -o prod.tpr" subprocess.call(shlex.split(grompp_cmd)) prod_cmd = ( - f"{gmx} mdrun -ntmpi 1 -ntomp {mp.cpu_count() - 1} -nb gpu -pin on -deffnm prod" + f"{gmx} mdrun -ntmpi 1 -ntomp {n_threads - 1} -nb gpu -pin on -deffnm prod" ) subprocess.call(shlex.split(prod_cmd)) @@ -156,6 +152,9 @@ def build_parser(): parser.add_argument( "--mdp", type=str, default=GRO_FILE_DIR, help="MDP file directory" ) + parser.add_argument( + "--threads", type=int, default=8, help="Threads (minimum 2)" + ) return parser @@ -170,7 +169,11 @@ def main(): with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) energy = run_gromacs( - args.pdb_file, tmpdir, gmx=args.gmxbin, gro_file_dir=args.mdp + args.pdb_file, + tmpdir, + gmx=args.gmxbin, + gro_file_dir=args.mdp, + n_threads=args.threads, ) for file in os.listdir(tmpdir): logging.debug(f"GROMACS file: {file}")