Fix workflow tests that are failing (#410)

This PR updates the tests so that all the examples run and if an example
fails then the test results in a failure as well. Changes include:

- Reformatting design_macrocyclic_binder.sh and
design_macrocyclic_monomer.sh to be submitted correctly by
test_diffusion.py
- Reducing the total length in design_tetrahedral_oligos.sh to reduce
run time of this test
- Changes to test_diffusion.py and main.yml to be able to run the
examples in different chunks so examples can run in parallel and to make
sure that if an example errors out, that the tests does not pass.

Currently design_ppi_scaffolded, design_timbarrel, and
design_ppi_flexible_peptide_with_secondarystructure_specification are
failing which should be addressed in other, future PRs.
This commit is contained in:
Hope Woods
2025-11-13 15:33:05 -06:00
committed by GitHub
5 changed files with 274 additions and 101 deletions

View File

@@ -69,6 +69,13 @@ jobs:
uv pip install --no-cache-dir -e . --no-deps
rm -rf ~/.cache # /app/RFdiffusion/tests
- name: Preseed DGL backend
shell: bash
run: |
mkdir -p "$HOME/.dgl"
printf '{"backend": "pytorch"}' > "$HOME/.dgl/config.conf"
echo "DGLBACKEND=pytorch" >> "$GITHUB_ENV"
- name: Download weights
run: |
mkdir models
@@ -87,8 +94,29 @@ jobs:
- name: Setup and Run ppi_scaffolds tests
run: |
tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples
cd tests && uv run python test_diffusion.py
total_chunks=$(nproc)
cd tests
#launch all chunks in background and record PIDs + labels
pids=""
for chunk_index in $(seq 1 $total_chunks); do
echo "Running chunk $chunk_index of $total_chunks"
uv run python test_diffusion.py --total_chunks $total_chunks --chunk_index $chunk_index &
pids="$pids $!"
done
# wait for each and track failures
fail=0
for pid in $pids; do
if ! wait "$pid"; then
echo "A chunk (PID $pid) failed"
fail=1
else
echo "A chunk (PID $pid) passed"
fi
done
exit "$fail"
# - name: Test with pytest
# run: |

View File

@@ -1,18 +1,15 @@
#!/bin/bash
prefix=./outputs/diffused_binder_cyclic2
# Note that in the example below the indices in the
# input_pdbs/7zkr_GABARAP.pdb file have been shifted
# by +2 in chain A relative to pdbID 7zkr.
# Note that the indices in this pdb file have been
# shifted by +2 in chain A relative to pdbID 7zkr.
pdb='./input_pdbs/7zkr_GABARAP.pdb'
num_designs=10
script="../scripts/run_inference.py"
$script --config-name base \
inference.output_prefix=$prefix \
inference.num_designs=$num_designs \
../scripts/run_inference.py \
--config-name base \
inference.output_prefix=example_outputs/diffused_binder_cyclic2 \
inference.num_designs=10 \
'contigmap.contigs=[12-18 A3-117/0]' \
inference.input_pdb=$pdb \
inference.input_pdb=./input_pdbs/7zkr_GABARAP.pdb \
inference.cyclic=True \
diffuser.T=50 \
inference.cyc_chains='a' \

View File

@@ -1,17 +1,15 @@
#!/bin/bash
prefix=./outputs/uncond_cycpep
# Note that the indices in this pdb file have been
# shifted by +2 in chain A relative to pdbID 7zkr.
pdb='./input_pdbs/7zkr_GABARAP.pdb'
# Note that in the example below the indices in the
# input_pdbs/7zkr_GABARAP.pdb file have been shifted
# by +2 in chain A relative to pdbID 7zkr.
num_designs=10
script="../scripts/run_inference.py"
$script --config-name base \
inference.output_prefix=$prefix \
inference.num_designs=$num_designs \
../scripts/run_inference.py \
--config-name base \
inference.output_prefix=example_outputs/uncond_cycpep \
inference.num_designs=10 \
'contigmap.contigs=[12-18]' \
inference.input_pdb=$pdb \
inference.input_pdb=input_pdbs/7zkr_GABARAP.pdb \
inference.cyclic=True \
diffuser.T=50 \
inference.cyc_chains='a'

View File

@@ -5,6 +5,6 @@
# This external potential promotes contacts both within (with a relative weight of 1) and between chains (relative weight 0.1)
# We specify that we want to apply these potentials to all chains, with a guide scale of 2.0 (a sensible starting point)
# We decay this potential with quadratic form, so that it is applied more strongly initially
# We specify a total length of 1200aa, so each chain is 100 residues long
# We specify a total length of 1200aa, so each chain is 100 residues long - length updated to 600aa, so each chain is 50 residues long for testing to run faster
python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="tetrahedral" inference.num_designs=10 inference.output_prefix="example_outputs/tetrahedral_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[1200-1200]'
python ../scripts/run_inference.py --config-name=symmetry inference.symmetry="tetrahedral" inference.num_designs=10 inference.output_prefix="example_outputs/tetrahedral_oligo" 'potentials.guiding_potentials=["type:olig_contacts,weight_intra:1,weight_inter:0.1"]' potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2.0 potentials.guide_decay="quadratic" 'contigmap.contigs=[600-600]'

View File

@@ -11,6 +11,7 @@ import sys, json
script_dir = os.path.dirname(os.path.abspath(__file__))
class TestSubmissionCommands(unittest.TestCase):
"""
Test harness for checking that commands in the examples folder,
@@ -25,88 +26,163 @@ class TestSubmissionCommands(unittest.TestCase):
outputs are the same as the reference outputs.
"""
def setUp(self):
failed_tests = []
# number of chunks to split examples into
total_chunks = 1
# which chunk to run
chunk_index = 1
out_f = None
results = {}
exec_status = {}
@classmethod
def setUpClass(cls):
"""
Grabs files from the examples folder
Class-level setup: Grabs files from the examples folder, discover & rewrite example commands, then execute them once.
"""
submissions = glob.glob(f"{script_dir}/../examples/*.sh")
# get datetime for output folder, in YYYY_MM_DD_HH_MM_SS format
chunks = cls.total_chunks
idx = cls.chunk_index
if chunks < 1:
raise ValueError("total_chunks must be at least 1")
if idx < 1 or idx > chunks:
raise ValueError(
"chunk_index must be between 1 and total_chunks (inclusive)"
)
if chunks > 1:
submissions = [
submissions[i]
for i in range(len(submissions))
if i % chunks == (idx - 1)
]
print(
f"Running chunk {idx}/{chunks}, {len(submissions)} submissions to run"
)
if not submissions:
raise ValueError("No submissions selected for chunk {idx} of {chunks}")
now = datetime.datetime.now()
now = now.strftime("%Y_%m_%d_%H_%M_%S")
self.out_f = f"{script_dir}/tests_{now}"
os.mkdir(self.out_f)
cls.out_f = f"{script_dir}/tests_{now}_{idx}"
os.mkdir(cls.out_f)
# Make sure we have access to all the relevant files
exclude_dirs = ["outputs", "example_outputs"]
for filename in os.listdir(f"{script_dir}/../examples"):
if filename not in exclude_dirs and not os.path.islink(os.path.join(script_dir, filename)) and os.path.isdir(os.path.join(f'{script_dir}/../examples', filename)):
os.symlink(os.path.join(f'{script_dir}/../examples', filename), os.path.join(script_dir, filename))
if (
filename not in exclude_dirs
and not os.path.exists(os.path.join(script_dir, filename))
and os.path.isdir(os.path.join(f"{script_dir}/../examples", filename))
):
try:
os.symlink(
os.path.join(f"{script_dir}/../examples", filename),
os.path.join(script_dir, filename),
)
except FileExistsError:
pass
for submission in submissions:
self._write_command(submission, self.out_f)
cls._write_command(submission, cls.out_f)
print(f"Running commands in {self.out_f}, two steps of diffusion, deterministic=True")
print(
f"Running commands in {cls.out_f}, two steps of diffusion, deterministic=True"
)
self.results = {}
cls.results = {}
cls.exec_status = {}
for bash_file in sorted( glob.glob(f"{self.out_f}/*.sh"), reverse=False):
test_name = os.path.basename(bash_file)[:-len('.sh')]
res, output = execute(f"Running {test_name}", f'bash {bash_file}', return_='tuple', add_message_and_command_line_to_output=True)
for bash_file in sorted(glob.glob(f"{cls.out_f}/*.sh"), reverse=False):
test_name = os.path.basename(bash_file)[: -len(".sh")]
res, output = execute(
f"Running {test_name}",
f"bash {bash_file}",
return_="tuple",
add_message_and_command_line_to_output=True,
)
cls.exec_status[test_name] = (res, output)
self.results[test_name] = dict(
state = 'failed' if res else 'passed',
log = output,
cls.results[test_name] = dict(
state="failed" if res else "passed",
log=output,
)
#subprocess.run(["bash", bash_file], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
#subprocess.run(["bash", bash_file])
# subprocess.run(["bash", bash_file], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
# subprocess.run(["bash", bash_file])
def test_examples_run_without_errors(self):
for name, (exit_code, output) in sorted(self.__class__.exec_status.items()):
with self.subTest(example=name):
if exit_code != 0:
self.__class__.failed_tests.append(f"{name}")
self.assertEqual(
exit_code,
0,
msg=f"Example '{name}' exited with {exit_code}\n{output}",
)
sys.stderr.write("\n==== EXAMPLE FAILURE SUMMARY ====\n")
for line in self.__class__.failed_tests:
sys.stderr.write(f" - {line}\n")
sys.stderr.write("=========================\n\n")
sys.stderr.flush()
def test_commands(self):
"""
Runs all the commands in the test_f folder
"""
reference=f'{script_dir}/reference_outputs'
reference = f"{script_dir}/reference_outputs"
os.makedirs(reference, exist_ok=True)
test_files=glob.glob(f"{self.out_f}/example_outputs/*pdb")
print(f'{self.out_f=} {test_files=}')
test_files = glob.glob(f"{self.__class__.out_f}/example_outputs/*pdb")
print(f"{self.__class__.out_f=} {test_files=}")
# first check that we have the right number of outputs
#self.assertEqual(len(test_files), len(glob.glob(f"{self.out_f}/*.sh"))), "One or more of the example commands didn't produce an output (check the example command is formatted correctly)"
# self.assertEqual(len(test_files), len(glob.glob(f"{self.out_f}/*.sh"))), "One or more of the example commands didn't produce an output (check the example command is formatted correctly)"
result = self.defaultTestResult()
for test_file in test_files:
with self.subTest(test_file=test_file):
test_pdb=iu.parse_pdb(test_file)
test_pdb = iu.parse_pdb(test_file)
if not os.path.exists(f"{reference}/{os.path.basename(test_file)}"):
copyfile(test_file, f"{reference}/{os.path.basename(test_file)}")
print(f"Created reference file {reference}/{os.path.basename(test_file)}")
print(
f"Created reference file {reference}/{os.path.basename(test_file)}"
)
else:
ref_pdb=iu.parse_pdb(f"{reference}/{os.path.basename(test_file)}")
rmsd=calc_rmsd(test_pdb['xyz'][:,:3].reshape(-1,3), ref_pdb['xyz'][:,:3].reshape(-1,3))[0]
ref_pdb = iu.parse_pdb(f"{reference}/{os.path.basename(test_file)}")
rmsd = calc_rmsd(
test_pdb["xyz"][:, :3].reshape(-1, 3),
ref_pdb["xyz"][:, :3].reshape(-1, 3),
)[0]
try:
self.assertAlmostEqual(rmsd, 0, 2)
result.addSuccess(self)
print(f"Subtest {test_file} passed")
state = 'passed'
log = f'Subtest {test_file} passed'
state = "passed"
log = f"Subtest {test_file} passed"
except AssertionError as e:
result.addFailure(self, e)
print(f"Subtest {test_file} failed")
state = 'failed'
log = f'Subtest {test_file} failed:\n{e!r}'
state = "failed"
log = f"Subtest {test_file} failed:\n{e!r}"
self.results[ 'pdb-diff.' + test_file.rpartition('/')[-1] ] = dict(state = state, log = log)
self.results["pdb-diff." + test_file.rpartition("/")[-1]] = dict(
state=state, log=log
)
with open('.results.json', 'w') as f: json.dump(self.results, f, sort_keys=True, indent=2)
with open(".results.json", "w") as f:
json.dump(self.results, f, sort_keys=True, indent=2)
self.assertTrue(result.wasSuccessful(), "One or more subtests failed")
def _write_command(self, bash_file, test_f) -> None:
@classmethod
def _write_command(cls, bash_file, test_f) -> None:
"""
Takes a bash file from the examples folder, and writes
a version of it to the test_f folder.
@@ -117,31 +193,42 @@ class TestSubmissionCommands(unittest.TestCase):
else:
inference.final_step=48
"""
out_lines=[]
out_lines = []
command_lines = []
in_command = False
with open(bash_file, "r") as f:
lines = f.readlines()
for line in lines:
if not (line.startswith("python") or line.startswith("../")):
out_lines.append(line)
for line in f:
stripped = line.strip()
if stripped.startswith("python") or stripped.startswith("../"):
in_command = True
if in_command:
# Remove trailing line continuation slashes
if stripped.endswith("\\"):
command_lines.append(stripped[:-1].strip())
else:
command_lines.append(stripped)
in_command = False # End of command
else:
command = line.strip()
if not command.startswith("python"):
command = f'python {command}'
out_lines.append(line)
if not command_lines:
raise ValueError(f"No valid python command found in {bash_file}")
command = " ".join(command_lines)
# get the partial_T
if "partial_T" in command:
final_step = int(command.split("partial_T=")[1].split(" ")[0]) - 2
else:
final_step = 48
output_command = f"{command} inference.deterministic=True inference.final_step={final_step}"
output_command = (
f"{command} inference.deterministic=True inference.final_step={final_step}"
)
# replace inference.num_designs with 1
if "inference.num_designs=" in output_command:
output_command = f'{output_command.split("inference.num_designs=")[0]}inference.num_designs=1 {" ".join(output_command.split("inference.num_designs=")[1].split(" ")[1:])}'
else:
output_command = f'{output_command} inference.num_designs=1'
output_command = f"{output_command} inference.num_designs=1"
# replace 'example_outputs' with f'{self.out_f}/example_outputs'
output_command = f'{output_command.split("example_outputs")[0]}{self.out_f}/example_outputs{output_command.split("example_outputs")[1]}'
output_command = f'{output_command.split("example_outputs")[0]}{cls.out_f}/example_outputs{output_command.split("example_outputs")[1]}'
# write the new command
with open(f"{test_f}/{os.path.basename(bash_file)}", "w") as f:
@@ -150,28 +237,38 @@ class TestSubmissionCommands(unittest.TestCase):
f.write(output_command)
def execute_through_pty(command_line):
import pty, select
if sys.platform == "darwin":
master, slave = pty.openpty()
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
stderr=subprocess.STDOUT, close_fds=True)
p = subprocess.Popen(
command_line,
shell=True,
stdout=slave,
stdin=slave,
stderr=subprocess.STDOUT,
close_fds=True,
)
buffer = []
while True:
try:
if select.select([master], [], [], 0.2)[0]: # has something to read
data = os.read(master, 1 << 22)
if data: buffer.append(data)
if data:
buffer.append(data)
elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
elif (p.poll() is not None) and (
not select.select([master], [], [], 0.2)[0]
):
break # process is finished and output buffer if fully read
except OSError: break # OSError will be raised when child process close PTY descriptior
except OSError:
break # OSError will be raised when child process close PTY descriptior
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
output = b"".join(buffer).decode(encoding="utf-8", errors="backslashreplace")
os.close(master)
os.close(slave)
@@ -179,7 +276,7 @@ def execute_through_pty(command_line):
p.wait()
exit_code = p.returncode
'''
"""
buffer = []
while True:
if select.select([master], [], [], 0.2)[0]: # has something to read
@@ -200,13 +297,19 @@ def execute_through_pty(command_line):
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
exit_code = p.returncode
'''
"""
else:
master, slave = pty.openpty()
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
stderr=subprocess.STDOUT, close_fds=True)
p = subprocess.Popen(
command_line,
shell=True,
stdout=slave,
stdin=slave,
stderr=subprocess.STDOUT,
close_fds=True,
)
os.close(slave)
@@ -214,10 +317,12 @@ def execute_through_pty(command_line):
while True:
try:
data = os.read(master, 1 << 22)
if data: buffer.append(data)
except OSError: break # OSError will be raised when child process close PTY descriptior
if data:
buffer.append(data)
except OSError:
break # OSError will be raised when child process close PTY descriptior
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
output = b"".join(buffer).decode(encoding="utf-8", errors="backslashreplace")
os.close(master)
@@ -227,39 +332,84 @@ def execute_through_pty(command_line):
return exit_code, output
def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, silence_output_on_errors=False, add_message_and_command_line_to_output=False):
if not silent: print(message); print(command_line); sys.stdout.flush();
def execute(
message,
command_line,
return_="status",
until_successes=False,
terminate_on_failure=True,
silent=False,
silence_output=False,
silence_output_on_errors=False,
add_message_and_command_line_to_output=False,
):
if not silent:
print(message)
print(command_line)
sys.stdout.flush()
while True:
#exit_code, output = execute_through_subprocess(command_line)
#exit_code, output = execute_through_pexpect(command_line)
# exit_code, output = execute_through_subprocess(command_line)
# exit_code, output = execute_through_pexpect(command_line)
exit_code, output = execute_through_pty(command_line)
if (exit_code and not silence_output_on_errors) or not (silent or silence_output): print(output); sys.stdout.flush();
if (exit_code and not silence_output_on_errors) or not (
silent or silence_output
):
print(output)
sys.stdout.flush()
if exit_code and until_successes: pass # Thats right - redability COUNT!
else: break
if exit_code and until_successes:
pass # Thats right - redability COUNT!
else:
break
print( "Error while executing {}: {}\n".format(message, output) )
print("Error while executing {}: {}\n".format(message, output))
print("Sleeping 60s... then I will retry...")
sys.stdout.flush();
sys.stdout.flush()
time.sleep(60)
if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + command_line + '\n' + output
if add_message_and_command_line_to_output:
output = message + "\nCommand line: " + command_line + "\n" + output
if return_ == 'tuple' or return_ == tuple: return(exit_code, output)
if return_ == "tuple" or return_ == tuple:
return (exit_code, output)
if exit_code and terminate_on_failure:
print("\nEncounter error while executing: " + command_line)
if return_==True: return True
if return_ == True:
return True
else:
print('\nEncounter error while executing: ' + command_line + '\n' + output);
raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output)
print("\nEncounter error while executing: " + command_line + "\n" + output)
raise BenchmarkError(
"\nEncounter error while executing: " + command_line + "\n" + output
)
if return_ == 'output': return output
else: return exit_code
if return_ == "output":
return output
else:
return exit_code
if __name__ == "__main__":
unittest.main()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--total_chunks",
type=int,
default=1,
help="total number of chunks to split the examples into (default: 1)",
)
parser.add_argument(
"--chunk_index",
type=int,
default=1,
help="Which chunk to run (1-based index, default:1)",
)
args, remaining = parser.parse_known_args()
TestSubmissionCommands.total_chunks = args.total_chunks
TestSubmissionCommands.chunk_index = args.chunk_index
unittest.main(argv=[sys.argv[0]] + remaining)