diff --git a/models/rfd3/src/rfd3/engine.py b/models/rfd3/src/rfd3/engine.py index 77cbc6f..7d7bd5b 100644 --- a/models/rfd3/src/rfd3/engine.py +++ b/models/rfd3/src/rfd3/engine.py @@ -48,9 +48,8 @@ class RFD3InferenceConfig: diffusion_batch_size: int = 16 # RFD3 specific - skip_existing: bool = False - json_keys_subset: Optional[List[str]] = None skip_existing: bool = True + json_keys_subset: Optional[List[str]] = None specification: Optional[dict] = field(default_factory=dict) inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict) @@ -216,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine): inputs=inputs, n_batches=n_batches, ) + if len(design_specifications) == 0: + ranked_logger.info("No design specifications to run. Skipping.") + return None ensure_inference_sampler_matches_design_spec( design_specifications, self.inference_sampler_overrides ) @@ -381,12 +383,18 @@ class RFD3InferenceEngine(BaseInferenceEngine): ) -> Dict[str, dict | DesignInputSpecification]: # Find existing example IDS in output directory if exists(self.out_dir): - existing_example_ids = set( + existing_example_ids_ = set( extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS) for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS) ) + existing_example_ids = set( + [ + "_model_".join(eid.split("_model_")[:-1]) + for eid in existing_example_ids_ + ] + ) ranked_logger.info( - f"Found {len(existing_example_ids)} existing example IDs in the output directory." + f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)." ) # Based on inputs, construct the specifications to loop through @@ -405,7 +413,6 @@ class RFD3InferenceEngine(BaseInferenceEngine): for batch_id in range((n_batches) if exists(n_batches) else 1): # ... Example ID example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix - if ( self.skip_existing and exists(self.out_dir)