import datetime
import glob
import json
import os
import re
import shlex
import subprocess
import sys
import time
import unittest
from shutil import copyfile

import torch

from rfdiffusion.inference import utils as iu
from rfdiffusion.util import calc_rmsd

script_dir = os.path.dirname(os.path.abspath(__file__))

# final_step used when the example command does not set partial_T.
_DEFAULT_FINAL_STEP = 48
# Captures the integer value of a `...partial_T=<int>` override.
_PARTIAL_T_RE = re.compile(r"partial_T=(\d+)")
# Matches an `inference.num_designs=<value>` override, stopping at whitespace
# or a quote so that a surrounding shell quote is left intact.
_NUM_DESIGNS_RE = re.compile(r"inference\.num_designs=[^\s'\"]*")
# Maximum number of bytes requested from the PTY per read.
_PTY_READ_SIZE = 1 << 22


class BenchmarkError(Exception):
    """Raised by `execute` when a command fails and termination is requested."""


def _rewrite_command(command, out_dir):
    """
    Return `command` rewritten for a short, deterministic test run.

    The rewritten command:
        - has `inference.deterministic=True` appended
        - has `inference.final_step` appended: partial_T - 2 when the command
          sets `partial_T=<int>`, otherwise 48
        - has `inference.num_designs` forced to 1 (replaced if present,
          appended otherwise)
        - has its first `example_outputs` occurrence redirected to
          `<out_dir>/example_outputs`

    Raises:
        ValueError: if the command never mentions `example_outputs`, since its
            outputs could then not be collected by the test harness.
    """
    partial_t = _PARTIAL_T_RE.search(command)
    final_step = int(partial_t.group(1)) - 2 if partial_t else _DEFAULT_FINAL_STEP
    rewritten = (
        f"{command} inference.deterministic=True inference.final_step={final_step}"
    )

    rewritten, replaced = _NUM_DESIGNS_RE.subn("inference.num_designs=1", rewritten)
    if not replaced:
        rewritten = f"{rewritten} inference.num_designs=1"

    if "example_outputs" not in rewritten:
        raise ValueError(
            "Example command does not write to example_outputs: " + command
        )
    # Only the first occurrence is redirected; anything after it is kept as is.
    return rewritten.replace("example_outputs", f"{out_dir}/example_outputs", 1)


class TestSubmissionCommands(unittest.TestCase):
    """
    Test harness for checking that commands in the examples folder,
    when run in deterministic mode, produce the same output as the
    reference outputs.

    Requirements:
        - example command must be a single command (it may be split over
          several lines with trailing backslashes)
        - outputs must be written to example_outputs folder
        - needs to be run on the same hardware as the reference outputs (A100 GPU)

    For speed, we only run the first 2 steps of diffusion, and set inference.num_designs=1
    This means that outputs DO NOT look like proteins, but we can still check that the
    outputs are the same as the reference outputs.
    """

    # names of the examples whose command exited with a non-zero status
    failed_tests = []
    # number of chunks to split examples into
    total_chunks = 1
    # which chunk to run (1-based)
    chunk_index = 1
    # folder the rewritten commands and their outputs are written to
    out_f = None
    # test name -> dict(state=..., log=...), dumped to .results.json
    results = {}
    # test name -> (exit_code, output) of the executed example command
    exec_status = {}

    @classmethod
    def setUpClass(cls):
        """
        Class-level setup: discover the example scripts, select this chunk,
        rewrite the example commands into a fresh output folder, then
        execute each of them once.

        Raises:
            ValueError: if total_chunks / chunk_index are out of range, or if
                no example script falls into the selected chunk.
        """
        examples_dir = f"{script_dir}/../examples"
        # Sorted so that every runner splits the examples into the same chunks
        # (glob returns files in arbitrary order).
        submissions = sorted(glob.glob(f"{examples_dir}/*.sh"))

        chunks = cls.total_chunks
        idx = cls.chunk_index
        if chunks < 1:
            raise ValueError("total_chunks must be at least 1")
        if idx < 1 or idx > chunks:
            raise ValueError(
                "chunk_index must be between 1 and total_chunks (inclusive)"
            )
        if chunks > 1:
            # Round-robin split: chunk k takes every chunks-th script from k-1.
            submissions = submissions[idx - 1 :: chunks]
            print(
                f"Running chunk {idx}/{chunks}, {len(submissions)} submissions to run"
            )
        if not submissions:
            raise ValueError(f"No submissions selected for chunk {idx} of {chunks}")

        # Output folder is named after the current time, YYYY_MM_DD_HH_MM_SS.
        now = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        cls.out_f = f"{script_dir}/tests_{now}_{idx}"
        os.mkdir(cls.out_f)

        # Make sure we have access to all the relevant files: symlink every
        # examples sub-directory (except output folders) next to this script.
        exclude_dirs = ("outputs", "example_outputs")
        for filename in os.listdir(examples_dir):
            source = os.path.join(examples_dir, filename)
            link = os.path.join(script_dir, filename)
            if (
                filename not in exclude_dirs
                and not os.path.exists(link)
                and os.path.isdir(source)
            ):
                try:
                    os.symlink(source, link)
                except FileExistsError:
                    # Another process created the link in the meantime.
                    pass

        for submission in submissions:
            cls._write_command(submission, cls.out_f)

        print(
            f"Running commands in {cls.out_f}, two steps of diffusion, deterministic=True"
        )

        cls.results = {}
        cls.exec_status = {}
        cls.failed_tests = []
        for bash_file in sorted(glob.glob(f"{cls.out_f}/*.sh")):
            test_name = os.path.basename(bash_file)[: -len(".sh")]
            res, output = execute(
                f"Running {test_name}",
                # Quoted so that a checkout path containing spaces still works.
                f"bash {shlex.quote(bash_file)}",
                return_="tuple",
                add_message_and_command_line_to_output=True,
            )
            cls.exec_status[test_name] = (res, output)
            cls.results[test_name] = dict(
                state="failed" if res else "passed",
                log=output,
            )

    def test_examples_run_without_errors(self):
        """
        Checks that every example command exited with status 0, and writes
        a summary of the failing examples to stderr.
        """
        failed = self.__class__.failed_tests
        for name, (exit_code, output) in sorted(self.__class__.exec_status.items()):
            with self.subTest(example=name):
                if exit_code != 0:
                    failed.append(f"{name}")
                self.assertEqual(
                    exit_code,
                    0,
                    msg=f"Example '{name}' exited with {exit_code}\n{output}",
                )

        if failed:
            sys.stderr.write("\n==== EXAMPLE FAILURE SUMMARY ====\n")
            for line in failed:
                sys.stderr.write(f"  - {line}\n")
            sys.stderr.write("=========================\n\n")
            sys.stderr.flush()

    def test_commands(self):
        """
        Compares every pdb written to <out_f>/example_outputs against the pdb
        of the same name in reference_outputs (backbone RMSD must be ~0).

        A pdb without a reference is copied into reference_outputs instead of
        being compared. All results are dumped to .results.json in the
        current working directory.
        """
        reference = f"{script_dir}/reference_outputs"
        os.makedirs(reference, exist_ok=True)
        test_files = sorted(glob.glob(f"{self.__class__.out_f}/example_outputs/*pdb"))
        print(f"{self.__class__.out_f=} {test_files=}")

        results = self.__class__.results
        mismatched = []
        for test_file in test_files:
            with self.subTest(test_file=test_file):
                pdb_name = os.path.basename(test_file)
                reference_file = f"{reference}/{pdb_name}"
                test_pdb = iu.parse_pdb(test_file)
                if not os.path.exists(reference_file):
                    copyfile(test_file, reference_file)
                    print(f"Created reference file {reference_file}")
                    continue

                ref_pdb = iu.parse_pdb(reference_file)
                # Only the first three atoms of each residue are compared.
                rmsd = calc_rmsd(
                    test_pdb["xyz"][:, :3].reshape(-1, 3),
                    ref_pdb["xyz"][:, :3].reshape(-1, 3),
                )[0]
                try:
                    self.assertAlmostEqual(rmsd, 0, 2)
                except AssertionError as e:
                    # Recorded rather than re-raised so that the results file
                    # still covers every pdb.
                    mismatched.append(test_file)
                    print(f"Subtest {test_file} failed")
                    state = "failed"
                    log = f"Subtest {test_file} failed:\n{e!r}"
                else:
                    print(f"Subtest {test_file} passed")
                    state = "passed"
                    log = f"Subtest {test_file} passed"
                results["pdb-diff." + pdb_name] = dict(state=state, log=log)

        with open(".results.json", "w") as f:
            json.dump(results, f, sort_keys=True, indent=2)

        self.assertFalse(mismatched, "One or more subtests failed")

    @classmethod
    def _write_command(cls, bash_file, test_f) -> None:
        """
        Takes a bash file from the examples folder, and writes a version of it
        to the test_f folder.

        The command is the line starting with `python` or `../`, together with
        its backslash-continued lines. All other lines are copied first, and
        the rewritten command is written last. It appends to the command the
        following arguments:
            inference.deterministic=True
            if partial_T is in the command, it grabs partial T and sets:
                inference.final_step=partial_T-2
            else:
                inference.final_step=48
        It also forces inference.num_designs=1 and redirects example_outputs
        to test_f/example_outputs.

        Raises:
            ValueError: if no command is found in bash_file, or the command
                does not write to example_outputs.
        """
        out_lines = []
        command_lines = []
        in_command = False
        with open(bash_file, "r") as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith(("python", "../")):
                    in_command = True
                if not in_command:
                    out_lines.append(line)
                elif stripped.endswith("\\"):
                    # Remove the trailing line continuation slash
                    command_lines.append(stripped[:-1].strip())
                else:
                    command_lines.append(stripped)
                    in_command = False  # End of command

        if not command_lines:
            raise ValueError(f"No valid python command found in {bash_file}")

        output_command = _rewrite_command(" ".join(command_lines), test_f)

        # write the new command
        with open(f"{test_f}/{os.path.basename(bash_file)}", "w") as f:
            for line in out_lines:
                # A last line without a newline would otherwise be glued to
                # the command.
                f.write(line if line.endswith("\n") else line + "\n")
            f.write(output_command + "\n")


def execute_through_pty(command_line):
    """
    Runs command_line through the shell with stdin/stdout/stderr attached to a
    pseudo-terminal and returns (exit_code, output).

    output is everything the command wrote, decoded as UTF-8 with undecodable
    bytes backslash-escaped.
    """
    # Imported here so that importing this module does not require pty.
    import pty, select

    is_darwin = sys.platform == "darwin"
    master, slave = pty.openpty()
    slave_open = True
    buffer = []
    try:
        p = subprocess.Popen(
            command_line,
            shell=True,
            stdout=slave,
            stdin=slave,
            stderr=subprocess.STDOUT,
            close_fds=True,
        )
        if not is_darwin:
            # Close our copy of the slave so that reading the master stops
            # once the child closes its end.
            os.close(slave)
            slave_open = False

        while True:
            try:
                if not is_darwin:
                    data = os.read(master, _PTY_READ_SIZE)
                    if not data:
                        break  # end of file: nothing more will arrive
                    buffer.append(data)
                elif select.select([master], [], [], 0.2)[0]:
                    # has something to read
                    data = os.read(master, _PTY_READ_SIZE)
                    if data:
                        buffer.append(data)
                elif (p.poll() is not None) and (
                    not select.select([master], [], [], 0.2)[0]
                ):
                    break  # process is finished and output buffer is fully read
            except OSError:
                break  # OSError will be raised when child process closes PTY descriptor
    finally:
        os.close(master)
        if slave_open:
            os.close(slave)

    output = b"".join(buffer).decode(encoding="utf-8", errors="backslashreplace")
    p.wait()
    return p.returncode, output


def execute(
    message,
    command_line,
    return_="status",
    until_successes=False,
    terminate_on_failure=True,
    silent=False,
    silence_output=False,
    silence_output_on_errors=False,
    add_message_and_command_line_to_output=False,
):
    """
    Runs command_line through `execute_through_pty` and reports the result.

    Args:
        message: description of the command, printed and used in error text.
        command_line: shell command to run.
        return_: selects what is returned:
            "tuple" (or the tuple type): (exit_code, output), never raises
            "output": the command output
            True: True when the command failed and terminate_on_failure is set
            anything else: the exit code
        until_successes: retry every 60 seconds until the command succeeds.
        terminate_on_failure: on failure, raise BenchmarkError (unless
            return_ is "tuple" or True).
        silent: do not print message and command_line; also silences the
            output of successful commands.
        silence_output: do not print the output of successful commands.
        silence_output_on_errors: do not print the output of failed commands
            (unless neither silent nor silence_output is set).
        add_message_and_command_line_to_output: prefix the returned output
            with message and command_line.

    Raises:
        BenchmarkError: if the command fails, terminate_on_failure is set and
            return_ is neither "tuple" nor True.
    """
    if not silent:
        print(message)
        print(command_line)
        sys.stdout.flush()

    while True:
        exit_code, output = execute_through_pty(command_line)

        if (exit_code and not silence_output_on_errors) or not (
            silent or silence_output
        ):
            print(output)
            sys.stdout.flush()

        if not (exit_code and until_successes):
            break

        print("Error while executing {}: {}\n".format(message, output))
        print("Sleeping 60s... then I will retry...")
        sys.stdout.flush()
        time.sleep(60)

    if add_message_and_command_line_to_output:
        output = message + "\nCommand line: " + command_line + "\n" + output

    if return_ == "tuple" or return_ == tuple:
        return (exit_code, output)

    if exit_code and terminate_on_failure:
        print("\nEncounter error while executing: " + command_line)
        if return_ == True:  # noqa: E712 - kept for backward compatibility
            return True
        print("\nEncounter error while executing: " + command_line + "\n" + output)
        raise BenchmarkError(
            "\nEncounter error while executing: " + command_line + "\n" + output
        )

    if return_ == "output":
        return output
    return exit_code


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--total_chunks",
        type=int,
        default=1,
        help="total number of chunks to split the examples into (default: 1)",
    )
    parser.add_argument(
        "--chunk_index",
        type=int,
        default=1,
        help="Which chunk to run (1-based index, default:1)",
    )

    # Arguments not recognised here are handed over to unittest.
    args, remaining = parser.parse_known_args()

    TestSubmissionCommands.total_chunks = args.total_chunks
    TestSubmissionCommands.chunk_index = args.chunk_index

    unittest.main(argv=[sys.argv[0]] + remaining)