diff --git a/src/modelhub/utils/datasets.py b/src/modelhub/utils/datasets.py index 2edb6c0..43f2326 100755 --- a/src/modelhub/utils/datasets.py +++ b/src/modelhub/utils/datasets.py @@ -189,7 +189,8 @@ def recursively_instantiate_datasets_and_samplers( # ... check that the sum of probabilities of all datasets is 1 assert ( - sum(dataset_info["probability"] for dataset_info in datasets_info) == 1.0 + abs(1 - sum(dataset_info["probability"] for dataset_info in datasets_info)) + < 1e-5 ), "Sum of probabilities must be 1.0" # ... compose the list of datasets into a single dataset