diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index c12989d716..9f6d5409ad 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -232,14 +232,25 @@ def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset: """ # Figure out which data source to use data_paths = parse_source_dataset(cfg) - print('Detected Datapaths: ', data_paths) + + # To be used when MLFlow implements fixes + 2.11.4 + # dataset_source_mapping = { + # 's3': mlflow.data.http_dataset_source.HTTPDatasetSource, + # 'oci': mlflow.data.http_dataset_source.HTTPDatasetSource, + # 'https': mlflow.data.http_dataset_source.HTTPDatasetSource, + # 'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource, + # 'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource, + # 'uc_volume': mlflow.data.delta_dataset_source.UCVolumeDatasetSource, + # 'local': mlflow.data.http_dataset_source.HTTPDatasetSource, + # } + dataset_source_mapping = { 's3': mlflow.data.http_dataset_source.HTTPDatasetSource, 'oci': mlflow.data.http_dataset_source.HTTPDatasetSource, 'https': mlflow.data.http_dataset_source.HTTPDatasetSource, 'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource, - 'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource, - 'uc_volume': UCVolumeDatasetSource, + 'delta_table': mlflow.data.http_dataset_source.HTTPDatasetSource, + 'uc_volume': mlflow.data.http_dataset_source.HTTPDatasetSource, 'local': mlflow.data.http_dataset_source.HTTPDatasetSource, } @@ -248,9 +259,7 @@ def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset: source_class = dataset_source_mapping.get(dataset_type) if source_class: - if dataset_type == 'delta_table': - source = source_class(delta_table_name=path) - elif dataset_type == 'uc_volume' or dataset_type == 'hf': + if dataset_type == 'hf': source = source_class(path=path) else: source = source_class(url=path) @@ -258,4 +267,4 @@ def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset: log.info(f'{dataset_type} unknown, defaulting to http dataset source') source =mlflow.data.http_dataset_source.HTTPDatasetSource(uri=path) - mlflow.log_input(mlflow.data.meta_dataset.MetaDataset(source, name=f'{split}')) \ No newline at end of file + mlflow.log_input(mlflow.data.meta_dataset.MetaDataset(source, name=f'{dataset_type} | {split}')) \ No newline at end of file diff --git a/llmfoundry/utils/uc_volume_dataset_source.py b/llmfoundry/utils/uc_volume_dataset_source.py index c06c55f4e4..a5f3ba6f27 100644 --- a/llmfoundry/utils/uc_volume_dataset_source.py +++ b/llmfoundry/utils/uc_volume_dataset_source.py @@ -23,7 +23,7 @@ class UCVolumeDatasetSource(DatasetSource): """ def __init__(self, path: str): - self._verify_uc_path_is_valid(path) + #self._verify_uc_path_is_valid(path) self.path = path def _verify_uc_path_is_valid(self, path):