From 764ccecdd75b27a6872fd8fb54be6156828cf0a8 Mon Sep 17 00:00:00 2001
From: Vincent Chen
Date: Fri, 22 Mar 2024 01:28:01 -0700
Subject: [PATCH] datasource

Rework parse_source_dataset so that it collects its results in a set of
(dataset type, path) tuples instead of a list plus a separate "seen" set.

For each of the 'train' and 'eval' splits the first matching rule wins:
  1. source_dataset_<split> with three or more dot-separated parts
     -> 'delta_table'
  2. source_dataset_<split> starting with '/Volumes' -> 'uc_volume'
  3. <split>_loader.dataset.hf_name -> the URI backend reported by
     parse_uri, or 'hf' when there is none
  4. <split>_loader.dataset.remote -> the URI backend reported by
     parse_uri, with '/<split name>' appended to the path when
     dataset.split is set
  5. <split>_loader.dataset.local -> 'local', with '/<split name>'
     appended to the path when dataset.split is set
---
 llmfoundry/utils/config_utils.py | 57 +++++++++++++++--------------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 302b764cab..088975ba83 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -180,40 +180,41 @@ def parse_source_dataset(cfg: DictConfig):
         cfg (DictConfig): run configuration
 
     Returns:
-        List[Tuple[str, str]]: A set of tuples where each tuple represents a dataset type ('local', 'hf', 'delta_table', 'uc_volume',
+        Set[Tuple[str, str]]: A set of tuples where each tuple represents a dataset type ('local', 'hf', 'delta_table', 'uc_volume',
                                remote backend) and the corresponding dataset path or identifier.
     """
-    paths = set()
-    data_paths = []
+    data_paths = set()
 
     for data_split in ['train', 'eval']:
         source_dataset_path = cfg.get(f'source_dataset_{data_split}', {})
-        backend, _, _ = parse_uri(source_dataset_path)
-
-        if backend:
-            remote_path = source_dataset_path
-
-        if source_dataset_path and source_dataset_path not in paths:
-            # check for Delta table
-            if len(source_dataset_path.split('.')) >= 3:
-                data_paths.append(('delta_table', source_dataset_path))
-            # check for UC volume
-            elif source_dataset_path.startswith('/Volumes'):
-                data_paths.append(('uc_volume', source_dataset_path))
-            # check for HF path
-            elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name'):
-                hf_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name')
-                backend, _, _ = parse_uri(hf_path)
-                if backend:
-                    data_paths.append((backend, hf_path))
-                else:
-                    data_paths.append(('hf', hf_path))
-            # check for Remote path
-            elif backend:
-                data_paths.append((backend, source_dataset_path))
+
+        # check for Delta table
+        if source_dataset_path and len(source_dataset_path.split('.')) >= 3:
+            data_paths.add(('delta_table', source_dataset_path))
+        # check for UC volume
+        elif source_dataset_path and source_dataset_path.startswith('/Volumes'):
+            data_paths.add(('uc_volume', source_dataset_path))
+        # check for HF path
+        elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name'):
+            hf_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name')
+            backend, _, _ = parse_uri(hf_path)
+            if backend:
+                data_paths.add((backend, hf_path))
             else:
-                data_paths.append(('local', source_dataset_path))
-                paths.add(source_dataset_path)
+                data_paths.add(('hf', hf_path))
+        # check for remote path
+        elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('remote', None):
+            remote_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('remote', None)
+            backend, _, _ = parse_uri(remote_path)
+            split = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('split', None)
+            remote_path = f'{remote_path}/{split}' if split else remote_path
+            data_paths.add((backend, remote_path))
+        # check for local path
+        elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('local', None):
+            local_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('local', None)
+            split = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('split', None)
+            local_path = f'{local_path}/{split}' if split else local_path
+            data_paths.add(('local', local_path))
 
     return data_paths
 