Fix `skip_existing` not working (#140)

This commit is contained in:
Jasper Butcher
2025-12-27 19:57:29 +00:00
committed by GitHub
parent 122f39b7e8
commit 2bfcd663bc

View File

@@ -48,9 +48,8 @@ class RFD3InferenceConfig:
diffusion_batch_size: int = 16
# RFD3 specific
skip_existing: bool = False
json_keys_subset: Optional[List[str]] = None
skip_existing: bool = True
json_keys_subset: Optional[List[str]] = None
specification: Optional[dict] = field(default_factory=dict)
inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
@@ -216,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
inputs=inputs,
n_batches=n_batches,
)
if len(design_specifications) == 0:
ranked_logger.info("No design specifications to run. Skipping.")
return None
ensure_inference_sampler_matches_design_spec(
design_specifications, self.inference_sampler_overrides
)
@@ -381,12 +383,18 @@ class RFD3InferenceEngine(BaseInferenceEngine):
) -> Dict[str, dict | DesignInputSpecification]:
# Find existing example IDs in the output directory
if exists(self.out_dir):
existing_example_ids = set(
existing_example_ids_ = set(
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
)
existing_example_ids = set(
[
"_model_".join(eid.split("_model_")[:-1])
for eid in existing_example_ids_
]
)
ranked_logger.info(
f"Found {len(existing_example_ids)} existing example IDs in the output directory."
f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
)
# Based on inputs, construct the specifications to loop through
@@ -405,7 +413,6 @@ class RFD3InferenceEngine(BaseInferenceEngine):
for batch_id in range((n_batches) if exists(n_batches) else 1):
# ... Example ID
example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
if (
self.skip_existing
and exists(self.out_dir)