mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
708 lines
20 KiB
Python
708 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""Submit AlphaFold3 functional tests to Slurm and summarize results.
|
|
|
|
This is a standalone wrapper for `test/cluster/check_alphafold3_predictions.py`.
|
|
It is intentionally not a pytest test module, despite the filename.
|
|
|
|
Typical usage from a login node:
|
|
|
|
python test/cluster/run_alphafold3_predictions.py
|
|
|
|
Run only selected tests:
|
|
|
|
python test/cluster/run_alphafold3_predictions.py -k chopped
|
|
|
|
Enable the runtime benchmark test as well:
|
|
|
|
python test/cluster/run_alphafold3_predictions.py --include-perf
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
__test__ = False
|
|
|
|
import argparse
|
|
import dataclasses
|
|
import datetime as dt
|
|
import importlib.util
|
|
import inspect
|
|
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
from _pytest.mark.expression import Expression
|
|
|
|
|
|
# Repository root: this script sits two directory levels below it.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Test module that is collected when no explicit node IDs are supplied.
DEFAULT_TEST_FILE = REPO_ROOT.joinpath("test", "cluster", "check_alphafold3_predictions.py")
# Parent directory of the per-run log directories.
DEFAULT_LOG_ROOT = REPO_ROOT.joinpath("test_logs")
|
|
|
|
# Slurm job states that count as a clean finish.
PASS_STATES = {"COMPLETED"}

# Slurm job states that mark the job as failed before its log is consulted.
FAIL_STATES = {
    "BOOT_FAIL", "CANCELLED", "DEADLINE",
    "FAILED", "NODE_FAIL", "OUT_OF_MEMORY",
    "PREEMPTED", "REVOKED", "TIMEOUT",
}
|
|
|
|
|
|
@dataclasses.dataclass(slots=True)
|
|
class JobSpec:
|
|
index: int
|
|
nodeid: str
|
|
slug: str
|
|
stdout_path: Path
|
|
stderr_path: Path
|
|
script_path: Path
|
|
rerun_command: str
|
|
job_id: str | None = None
|
|
slurm_state: str | None = None
|
|
exit_code: str | None = None
|
|
outcome: str | None = None
|
|
reason: str | None = None
|
|
|
|
|
|
def _has_cmd(cmd: str) -> bool:
|
|
try:
|
|
subprocess.run(
|
|
[cmd, "--help"],
|
|
stdout=subprocess.DEVNULL,
|
|
stderr=subprocess.DEVNULL,
|
|
check=False,
|
|
)
|
|
return True
|
|
except FileNotFoundError:
|
|
return False
|
|
|
|
|
|
def _run(
    cmd: list[str],
    *,
    cwd: Path = REPO_ROOT,
    check: bool = True,
) -> subprocess.CompletedProcess[str]:
    """Run *cmd* and return the finished process with captured text output.

    Args:
        cmd: Argument vector to execute (no shell is involved).
        cwd: Working directory for the child process.
        check: When true, a non-zero exit status raises
            ``subprocess.CalledProcessError``.

    Returns:
        The ``CompletedProcess`` whose ``stdout``/``stderr`` are decoded text.
    """
    completed = subprocess.run(
        cmd,
        check=check,
        cwd=cwd,
        capture_output=True,
        text=True,
    )
    return completed
|
|
|
|
|
|
def _normalize_state(state: str | None) -> str | None:
|
|
if not state:
|
|
return None
|
|
return state.split()[0].rstrip("+")
|
|
|
|
|
|
def _slugify(value: str, *, max_len: int = 120) -> str:
|
|
slug = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._")
|
|
if not slug:
|
|
slug = "test"
|
|
if len(slug) > max_len:
|
|
slug = slug[:max_len].rstrip("._")
|
|
return slug
|
|
|
|
|
|
def _quote(value: str) -> str:
|
|
return shlex.quote(value)
|
|
|
|
|
|
def _timestamp() -> str:
|
|
return dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
|
|
def _relative_nodeid_prefix(test_file: Path) -> str:
    """Return *test_file* as a path relative to the repository root.

    The file is resolved first. ``Path.relative_to`` raises ``ValueError``
    when the resolved file is not located under ``REPO_ROOT``.
    """
    resolved = test_file.resolve()
    inside_repo = resolved.relative_to(REPO_ROOT)
    return str(inside_repo)
|
|
|
|
|
|
def _matches_k_expression(nodeid: str, k_expr: str | None) -> bool:
|
|
if not k_expr:
|
|
return True
|
|
expression = Expression.compile(k_expr)
|
|
lowered = nodeid.lower()
|
|
return expression.evaluate(lambda token: token.lower() in lowered)
|
|
|
|
|
|
def _collect_nodeids_from_module_import(test_file: Path, k_expr: str | None) -> list[str]:
    """Build node IDs by importing *test_file* and inspecting its classes.

    The module is executed under a temporary name that is removed from
    ``sys.modules`` again afterwards. Only classes defined in that module
    which subclass ``unittest.TestCase`` and whose name starts with ``Test``
    are considered; each attribute name starting with ``test`` yields one
    node ID, which is kept when it satisfies *k_expr*.

    Raises:
        RuntimeError: If no import spec (or loader) can be created.
    """
    module_name = f"_codex_collect_{test_file.stem}"
    spec = importlib.util.spec_from_file_location(module_name, test_file)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to create import spec for {test_file}")

    module = importlib.util.module_from_spec(spec)
    # Registered while executing so code that looks itself up in sys.modules works.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    finally:
        sys.modules.pop(module_name, None)

    prefix = _relative_nodeid_prefix(test_file)

    def _is_local_test_case(candidate: type) -> bool:
        return (
            candidate.__module__ == module.__name__
            and issubclass(candidate, unittest.TestCase)
            and candidate.__name__.startswith("Test")
        )

    candidates = (
        f"{prefix}::{cls.__name__}::{method_name}"
        for _, cls in inspect.getmembers(module, inspect.isclass)
        if _is_local_test_case(cls)
        for method_name in sorted(attr for attr in dir(cls) if attr.startswith("test"))
    )
    return [nodeid for nodeid in candidates if _matches_k_expression(nodeid, k_expr)]
|
|
|
|
|
|
def collect_nodeids(
    *,
    python_executable: str,
    test_file: Path,
    k_expr: str | None,
) -> list[str]:
    """Collect pytest node IDs from *test_file*.

    Runs ``pytest --collect-only -q`` with *python_executable* and keeps the
    output lines that look like node IDs. When pytest yields none, the file
    is imported directly and inspected as a fallback.

    Pytest exits with status 5 when collection itself worked but no test was
    collected or selected (for example a ``-k`` expression that matches
    nothing). That status is treated as an empty result rather than as a
    collection failure, so the fallback and the caller's "nothing matched"
    handling can run.

    Args:
        python_executable: Interpreter used to run pytest.
        test_file: Test module to collect from.
        k_expr: Optional ``-k`` selection expression.

    Returns:
        The collected node IDs (possibly empty).

    Raises:
        RuntimeError: If pytest exits with any status other than 0 or 5.
    """
    no_tests_collected = 5  # pytest's exit status for "no tests collected"

    cmd = [
        python_executable,
        "-m",
        "pytest",
        "-o",
        "addopts=-ra --strict-markers",
        "--collect-only",
        "-q",
        str(test_file),
    ]
    if k_expr:
        cmd.extend(["-k", k_expr])

    result = _run(cmd, check=False)
    if result.returncode not in (0, no_tests_collected):
        raise RuntimeError(
            "pytest collection failed.\n"
            f"Command: {' '.join(_quote(part) for part in cmd)}\n\n"
            f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
        )

    # Node IDs are the lines containing ".py::"; report lines are skipped.
    nodeids = [
        line
        for line in (raw_line.strip() for raw_line in result.stdout.splitlines())
        if ".py::" in line and not line.startswith(("ERROR ", "SKIPPED "))
    ]
    if nodeids:
        return nodeids

    return _collect_nodeids_from_module_import(test_file, k_expr)
|
|
|
|
|
|
def write_job_script(
    *,
    job: JobSpec,
    python_executable: str,
    use_temp_dir: bool,
    include_perf: bool,
) -> None:
    """Write the executable batch script that runs one test node.

    The script changes into the repository root, exports the job
    environment, echoes some context (node ID, host, ``python`` on PATH) and
    then runs pytest for ``job.nodeid``. It is written to ``job.script_path``
    and made executable (mode 0755).

    Args:
        job: Job whose ``nodeid`` is run and whose ``script_path`` is written.
        python_executable: Interpreter used to invoke pytest inside the job.
        use_temp_dir: Append ``--use-temp-dir`` to the pytest command.
        include_perf: Export ``AF3_RUN_PERF_TESTS=1`` inside the job.
    """
    pytest_argv = [
        python_executable,
        "-m",
        "pytest",
        "-o",
        "addopts=-ra --strict-markers",
        "-vv",
        "-s",
        job.nodeid,
    ]
    if use_temp_dir:
        pytest_argv += ["--use-temp-dir"]

    lines = [
        "#!/bin/bash",
        "set -euo pipefail",
        f"cd {_quote(str(REPO_ROOT))}",
        "export PYTHONUNBUFFERED=1",
    ]
    if include_perf:
        lines.append("export AF3_RUN_PERF_TESTS=1")
    lines += [
        'echo "[$(date)] Running test node:"',
        f"echo {_quote(job.nodeid)}",
        'echo "[$(date)] Host: $(hostname)"',
        'echo "[$(date)] Python: $(which python || true)"',
        " ".join(map(_quote, pytest_argv)),
        "",  # Ensures the script ends with a newline.
    ]

    target = job.script_path
    target.write_text("\n".join(lines), encoding="utf-8")
    target.chmod(0o755)
|
|
|
|
|
|
def submit_job(job: JobSpec, args: argparse.Namespace) -> str:
    """Submit *job*'s script with ``sbatch`` and return the Slurm job ID.

    Resource options come from *args*; optional ones are only passed when
    set. ``--parsable`` output may look like ``<jobid>;<cluster>``, so only
    the part before the first ``;`` of the last output line is returned.

    Raises:
        subprocess.CalledProcessError: If ``sbatch`` exits non-zero.
    """
    cmd = [
        "sbatch",
        "--parsable",
        "--export=ALL",
        f"--job-name={args.job_name_prefix}_{job.index:03d}",
        f"--chdir={REPO_ROOT}",
        f"--output={job.stdout_path}",
        f"--error={job.stderr_path}",
        f"--time={args.time}",
        "--ntasks=1",
        f"--cpus-per-task={args.cpus_per_task}",
        f"--mem={args.mem}",
    ]

    optional_flags = (
        ("partition", args.partition),
        ("qos", args.qos),
        ("constraint", args.constraint),
        ("account", args.account),
        ("gres", args.gres),
    )
    cmd.extend(f"--{flag}={value}" for flag, value in optional_flags if value)
    cmd.extend(args.extra_sbatch_arg)
    cmd.append(str(job.script_path))

    completed = _run(cmd)
    last_line = completed.stdout.strip().splitlines()[-1]
    job_id, _, _ = last_line.partition(";")
    return job_id
|
|
|
|
|
|
def active_job_ids(job_ids: Iterable[str]) -> set[str]:
    """Return the subset of *job_ids* that ``squeue`` still lists.

    Empty IDs are ignored; with nothing left to ask about, no command is run.

    A failing ``squeue`` call is interpreted by its error text:

    * "Invalid job id" means the queue no longer knows the job(s), so they
      are reported as no longer active.
    * Any other failure (for example a controller timeout) says nothing
      about the jobs. They are then assumed to be still active, and a
      warning is printed to stderr, so that a transient outage does not make
      the caller believe every job has finished.

    Args:
        job_ids: Slurm job IDs to check.

    Returns:
        The job IDs considered still pending or running.
    """
    wanted = [job_id for job_id in job_ids if job_id]
    if not wanted:
        return set()

    result = _run(
        ["squeue", "-h", "-j", ",".join(wanted), "-o", "%A"],
        check=False,
    )
    if result.returncode != 0:
        stderr_text = result.stderr or ""
        if "invalid job id" in stderr_text.lower():
            return set()
        print(
            f"[wait] squeue failed (exit {result.returncode}); "
            f"assuming jobs are still active: {stderr_text.strip()}",
            file=sys.stderr,
            flush=True,
        )
        return set(wanted)

    return {line.strip() for line in result.stdout.splitlines() if line.strip()}
|
|
|
|
|
|
def query_sacct(job_id: str) -> tuple[str | None, str | None]:
    """Look up the final state and exit code of *job_id* via ``sacct``.

    Args:
        job_id: Slurm job ID. An empty ID (a job that was never submitted)
            is answered immediately without running any command.

    Returns:
        ``(state, exit_code)`` where *state* is normalized to its bare name.
        ``(None, None)`` is returned when the ID is empty, ``sacct`` is not
        available, the query fails, or no row matches the job.
    """
    if not job_id:
        return None, None
    if not _has_cmd("sacct"):
        return None, None

    result = _run(
        # -X: allocation rows only, -n: no header, -P: '|'-separated output.
        ["sacct", "-X", "-n", "-P", "-j", job_id, "-o", "JobIDRaw,State,ExitCode"],
        check=False,
    )
    if result.returncode != 0:
        return None, None

    for line in result.stdout.splitlines():
        fields = line.strip().split("|")
        if len(fields) < 3:
            continue
        job_id_raw, state, exit_code = fields[:3]
        if job_id_raw == job_id:
            return _normalize_state(state), exit_code
    return None, None
|
|
|
|
|
|
def wait_for_jobs(jobs: list[JobSpec], *, poll_interval: int, timeout_seconds: int | None) -> None:
    """Block until none of the submitted *jobs* is listed as active any more.

    Jobs without a ``job_id`` are not waited for. Progress is printed each
    time the number of remaining jobs changes.

    Args:
        jobs: Jobs to wait for.
        poll_interval: Seconds to sleep between polls.
        timeout_seconds: Overall limit in seconds, or ``None`` for no limit.

    Raises:
        TimeoutError: If jobs are still outstanding after *timeout_seconds*.
    """
    pending = {job.job_id for job in jobs if job.job_id}
    total = len(jobs)
    started_at = time.monotonic()
    last_reported = len(pending)

    while pending:
        if timeout_seconds is not None:
            waited = time.monotonic() - started_at
            if waited > timeout_seconds:
                raise TimeoutError(
                    f"Timed out waiting for {len(pending)} Slurm job(s): "
                    + ", ".join(sorted(pending))
                )

        still_active = active_job_ids(pending)
        newly_done = pending - still_active
        if newly_done:
            pending = still_active

        remaining = len(pending)
        if newly_done or remaining != last_reported:
            done = total - remaining
            print(f"[wait] {done}/{total} jobs finished, {remaining} remaining", flush=True)
            last_reported = remaining

        if not pending:
            break
        time.sleep(poll_interval)
|
|
|
|
|
|
def _combined_log_text(job: JobSpec) -> str:
|
|
parts: list[str] = []
|
|
if job.stdout_path.exists():
|
|
parts.append(job.stdout_path.read_text(encoding="utf-8", errors="replace"))
|
|
if job.stderr_path.exists():
|
|
stderr_text = job.stderr_path.read_text(encoding="utf-8", errors="replace")
|
|
if stderr_text:
|
|
parts.append(stderr_text)
|
|
return "\n".join(parts)
|
|
|
|
|
|
def _extract_reason_from_log(text: str) -> str:
|
|
patterns = [
|
|
r"short test summary info[\s\S]*$",
|
|
r"=+ FAILURES =+[\s\S]*$",
|
|
r"Traceback[\s\S]*$",
|
|
r"(?m)^E\s+.*$",
|
|
r"(?m)^FAILED .*$",
|
|
r"(?m)^ERROR .*$",
|
|
r"(?m)^.*Killed.*$",
|
|
r"(?m)^.*PASSED.*$",
|
|
r"(?m)^.*SKIPPED.*$",
|
|
]
|
|
for pattern in patterns:
|
|
match = re.search(pattern, text)
|
|
if match:
|
|
snippet = match.group(0).strip()
|
|
if len(snippet) > 1200:
|
|
snippet = snippet[-1200:]
|
|
return snippet
|
|
|
|
non_empty_lines = [line.rstrip() for line in text.splitlines() if line.strip()]
|
|
if not non_empty_lines:
|
|
return "No log output captured."
|
|
return "\n".join(non_empty_lines[-20:])
|
|
|
|
|
|
def classify_job(job: JobSpec) -> None:
    """Fill in ``slurm_state``, ``exit_code``, ``outcome`` and ``reason`` of *job*.

    Checks are applied in this order; the first that holds decides:

    1. A Slurm state in ``FAIL_STATES`` gives ``FAILED`` (the reason is
       prefixed with that state).
    2. Failure markers in the log ("killed", ``FAILED``/``ERROR`` lines, a
       failures banner, a traceback) give ``FAILED``.
    3. Skip markers in the log give ``SKIPPED``.
    4. Pass markers in the log give ``PASSED``.
    5. A Slurm state in ``PASS_STATES`` gives ``PASSED``.
    6. Otherwise the outcome is ``UNKNOWN``.
    """
    job.slurm_state, job.exit_code = query_sacct(job.job_id or "")
    log_text = _combined_log_text(job)
    slurm_state = job.slurm_state
    log_reason = _extract_reason_from_log(log_text)

    if slurm_state in FAIL_STATES:
        job.outcome = "FAILED"
        job.reason = f"Slurm state: {slurm_state}\n{log_reason}"
        return

    log_shows_failure = (
        re.search(r"(?im)\bkilled\b", log_text)
        or re.search(r"(?m)^FAILED ", log_text)
        or re.search(r"(?m)^ERROR ", log_text)
        or re.search(r"=+ FAILURES =+", log_text)
        or "Traceback" in log_text
    )
    if log_shows_failure:
        verdict = "FAILED"
    elif re.search(r"(?i)\b\d+\s+skipped\b", log_text) or " SKIPPED" in log_text:
        verdict = "SKIPPED"
    elif re.search(r"(?i)\b\d+\s+passed\b", log_text) or " PASSED" in log_text:
        verdict = "PASSED"
    elif slurm_state in PASS_STATES:
        verdict = "PASSED"
    else:
        verdict = "UNKNOWN"

    job.outcome = verdict
    job.reason = log_reason
|
|
|
|
|
|
def write_summary(log_dir: Path, jobs: list[JobSpec]) -> Path:
    """Write ``summary.json`` into *log_dir* and return its path.

    The document records when it was generated, the repository root, and one
    entry per job with its identifiers, Slurm result, outcome, log file
    locations, re-run command and reason.
    """
    records = []
    for job in jobs:
        records.append(
            {
                "index": job.index,
                "nodeid": job.nodeid,
                "job_id": job.job_id,
                "slurm_state": job.slurm_state,
                "exit_code": job.exit_code,
                "outcome": job.outcome,
                "stdout_log": str(job.stdout_path),
                "stderr_log": str(job.stderr_path),
                "rerun_command": job.rerun_command,
                "reason": job.reason,
            }
        )

    document = {
        "generated_at": dt.datetime.now().isoformat(),
        "repo_root": str(REPO_ROOT),
        "jobs": records,
    }
    destination = log_dir / "summary.json"
    destination.write_text(json.dumps(document, indent=2), encoding="utf-8")
    return destination
|
|
|
|
|
|
def print_summary(jobs: list[JobSpec], summary_path: Path) -> int:
    """Print per-outcome counts and details of problematic jobs.

    A job is a problem unless its outcome is ``PASSED`` or ``SKIPPED``. For
    each problem the job's identifiers, log paths, re-run command and up to
    20 lines of its reason are printed.

    Returns:
        ``1`` when at least one job is a problem, otherwise ``0``.
    """
    tally: dict[str, int] = {}
    for job in jobs:
        label = job.outcome or "UNKNOWN"
        tally[label] = tally.get(label, 0) + 1

    print("\nSummary")
    for label, count in sorted(tally.items()):
        print(f" {label}: {count}")
    print(f" summary_json: {summary_path}")

    acceptable = {"PASSED", "SKIPPED"}
    problem_jobs = [job for job in jobs if job.outcome not in acceptable]
    if not problem_jobs:
        return 0

    print("\nProblems")
    for job in problem_jobs:
        print(f" {job.nodeid}")
        print(f" slurm_job: {job.job_id}")
        print(f" state: {job.slurm_state or 'unknown'}")
        print(f" stdout: {job.stdout_path}")
        print(f" stderr: {job.stderr_path}")
        print(f" rerun: {job.rerun_command}")
        if job.reason:
            for line in job.reason.splitlines()[:20]:
                print(f" {line}")
    return 1
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Options cover test selection (node IDs, ``--test-file``, ``-k``,
    ``--max-tests``), run modes (``--list``, ``--dry-run``), test behavior
    (``--include-perf``, ``--use-temp-dir``), Slurm resources, polling and
    timeout, the log directory and the Python interpreter to use.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Submit AlphaFold3 functional tests to Slurm in parallel, wait for completion, "
            "and summarize the logs."
        )
    )

    # --- test selection ---
    cli.add_argument(
        "nodeid",
        nargs="*",
        help=(
            "Optional exact pytest node IDs to submit. If omitted, tests are collected "
            f"from {DEFAULT_TEST_FILE.relative_to(REPO_ROOT)}."
        ),
    )
    cli.add_argument(
        "--test-file",
        default=str(DEFAULT_TEST_FILE),
        help="Pytest file to collect from. Defaults to test/cluster/check_alphafold3_predictions.py",
    )
    cli.add_argument(
        "-k",
        dest="k_expr",
        default=None,
        help="Optional pytest -k expression applied during collection.",
    )
    cli.add_argument(
        "--max-tests",
        type=int,
        default=None,
        help="Submit at most this many collected tests.",
    )

    # --- run modes and test behavior ---
    cli.add_argument(
        "--list",
        action="store_true",
        help="List collected node IDs and exit without submitting jobs.",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Collect tests and write job scripts, but do not call sbatch.",
    )
    cli.add_argument(
        "--include-perf",
        action="store_true",
        help="Set AF3_RUN_PERF_TESTS=1 inside jobs so the runtime benchmark is included.",
    )
    cli.add_argument(
        "--use-temp-dir",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Run target tests with isolated temporary output directories. "
            "Use --no-use-temp-dir to keep the shared repo output tree."
        ),
    )

    # --- Slurm resources (plain string options first, in help order) ---
    string_options = (
        ("--partition", "gpu-el8", "Slurm partition/queue."),
        ("--qos", "normal", "Slurm QoS."),
        ("--constraint", "gaming", "Optional Slurm constraint."),
        ("--account", None, "Optional Slurm account."),
        ("--gres", "gpu:1", "Slurm gres request, for example gpu:1."),
        ("--time", "12:00:00", "Per-job walltime."),
    )
    for flag, default, help_text in string_options:
        cli.add_argument(flag, default=default, help=help_text)
    cli.add_argument("--cpus-per-task", type=int, default=8, help="CPUs per Slurm task.")
    cli.add_argument("--mem", default="16G", help="Per-job memory request.")
    cli.add_argument(
        "--extra-sbatch-arg",
        action="append",
        default=[],
        help="Additional raw sbatch argument. Can be passed multiple times.",
    )
    cli.add_argument(
        "--job-name-prefix",
        default="af3test",
        help="Prefix for Slurm job names.",
    )

    # --- waiting, logging, interpreter ---
    cli.add_argument(
        "--poll-interval",
        type=int,
        default=30,
        help="Seconds between Slurm polling cycles.",
    )
    cli.add_argument(
        "--wait-timeout-hours",
        type=float,
        default=24.0,
        help="Maximum hours to wait for all submitted jobs. Use 0 to disable.",
    )
    cli.add_argument(
        "--log-dir",
        default=None,
        help="Directory to write job scripts and logs into. Defaults to test_logs/alphafold3_<timestamp>.",
    )
    cli.add_argument(
        "--python",
        default=sys.executable,
        help="Python executable used both for collection and inside Slurm jobs.",
    )

    return cli.parse_args()
|
|
|
|
|
|
def main() -> int:
    """Collect tests, submit one Slurm job per test, wait, and summarize.

    Flow:

    1. Parse arguments; when jobs will really be submitted, require
       ``sbatch`` and ``squeue``.
    2. Determine the node IDs (explicit arguments, otherwise collection),
       optionally capped by ``--max-tests``.
    3. ``--list`` prints the node IDs and stops.
    4. Create the log directory and write one job script per node.
       ``--dry-run`` stops here after printing the prepared files.
    5. Submit all jobs, wait for them, classify each one, write
       ``summary.json`` and print the summary.

    The log directory is created only once something is going to be written
    into it, so ``--list`` and an empty selection leave no directory behind.
    The Slurm tools are only probed when jobs will be submitted.

    Returns:
        ``0`` when nothing was selected, after ``--list``/``--dry-run``, or
        when every job passed or was skipped; ``1`` if any job is a problem.

    Raises:
        SystemExit: For a missing Slurm tool, a missing test file, or an
            unsafe ``--no-use-temp-dir`` run with more than one test.
    """
    args = parse_args()

    will_submit = not args.list and not args.dry_run
    if will_submit:
        for tool in ("sbatch", "squeue"):
            if not _has_cmd(tool):
                raise SystemExit(f"{tool} is not available in PATH.")

    test_file = Path(args.test_file).resolve()
    if not test_file.exists():
        raise SystemExit(f"Test file does not exist: {test_file}")

    if args.log_dir:
        log_dir = Path(args.log_dir).resolve()
    else:
        log_dir = (DEFAULT_LOG_ROOT / f"alphafold3_{_timestamp()}").resolve()

    if args.nodeid:
        nodeids = list(args.nodeid)
    else:
        nodeids = collect_nodeids(
            python_executable=args.python,
            test_file=test_file,
            k_expr=args.k_expr,
        )

    if args.max_tests is not None:
        nodeids = nodeids[: args.max_tests]

    if not nodeids:
        print("No tests matched the requested selection.")
        return 0

    if not args.use_temp_dir and len(nodeids) > 1:
        raise SystemExit(
            "--no-use-temp-dir is not safe for parallel AF3 wrapper runs because "
            "the tests share and clean common output roots. Re-run with the default "
            "--use-temp-dir, or submit a single nodeid at a time."
        )

    if args.list:
        for nodeid in nodeids:
            print(nodeid)
        return 0

    # From here on files are written, so the directory is needed.
    log_dir.mkdir(parents=True, exist_ok=True)

    print(f"Collected {len(nodeids)} test node(s).")
    print(f"Log directory: {log_dir}")

    jobs: list[JobSpec] = []
    for index, nodeid in enumerate(nodeids, start=1):
        slug = _slugify(nodeid)
        file_stem = f"{index:03d}_{slug}"
        rerun_command = (
            f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
            + (" --use-temp-dir" if args.use_temp_dir else "")
        )
        job = JobSpec(
            index=index,
            nodeid=nodeid,
            slug=slug,
            stdout_path=log_dir / f"{file_stem}.out",
            stderr_path=log_dir / f"{file_stem}.err",
            script_path=log_dir / f"{file_stem}.sbatch.sh",
            rerun_command=rerun_command,
        )
        write_job_script(
            job=job,
            python_executable=args.python,
            use_temp_dir=args.use_temp_dir,
            include_perf=args.include_perf,
        )
        jobs.append(job)

    if args.dry_run:
        print("Dry run only. Prepared job scripts:")
        for job in jobs:
            print(f" {job.nodeid}")
            print(f" script: {job.script_path}")
            print(f" stdout: {job.stdout_path}")
            print(f" stderr: {job.stderr_path}")
        return 0

    for job in jobs:
        job.job_id = submit_job(job, args)
        print(f"[submit] {job.job_id} {job.nodeid}")

    # A non-positive --wait-timeout-hours disables the overall timeout.
    timeout_seconds: int | None
    if args.wait_timeout_hours <= 0:
        timeout_seconds = None
    else:
        timeout_seconds = int(args.wait_timeout_hours * 3600)

    wait_for_jobs(
        jobs,
        poll_interval=args.poll_interval,
        timeout_seconds=timeout_seconds,
    )

    for job in jobs:
        classify_job(job)

    summary_path = write_summary(log_dir, jobs)
    return print_summary(jobs, summary_path)
|
|
|
|
|
|
# Script entry point: the process exit status is main()'s return value.
if __name__ == "__main__":
    sys.exit(main())
|