diff --git a/src/foqa/dataset.py b/src/foqa/dataset.py index df6a303..9a60cfc 100644 --- a/src/foqa/dataset.py +++ b/src/foqa/dataset.py @@ -31,11 +31,6 @@ def build_dataset(config: DictConfig) -> None: ) assert isinstance(dataset, Dataset) - num_samples = min(config.num_samples, len(dataset)) - if num_samples < config.num_samples: - logger.info(f"Reduced number of samples to the maximal {num_samples:,}.") - dataset = dataset.select(range(num_samples)) - records_path = Path(config.dirs.data) / config.dirs.raw / "records.jsonl" if records_path.exists(): with records_path.open() as f: @@ -45,6 +40,13 @@ def build_dataset(config: DictConfig) -> None: with tqdm(dataset, desc=f"Generating samples with {config.model}") as pbar: for sample in pbar: + if len(records) >= config.num_samples: + logger.info( + f"Reached the target number of samples ({config.num_samples:,}). " + "Stopping." + ) + break + sample_exists = any(record["id"] == sample["url"] for record in records) if sample_exists: continue