Skip to content

Commit

Permalink
more code fix
Browse files — browse the repository at this point in the history
  • Loading branch information
KuuCi committed Apr 18, 2024
1 parent aae821d commit 996fb01
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging
import math
import warnings
import mlflow
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union

import mlflow
from composer.utils import dist, parse_uri
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
Expand Down Expand Up @@ -187,9 +187,11 @@ def log_config(cfg: DictConfig) -> None:
mlflow.log_params(params=om.to_container(cfg, resolve=True))
log_dataset_uri(cfg)


def parse_source_dataset(cfg: DictConfig):
"""
This function parses a run config for dataset information related to training and evaluation stages.
Parses a run config for dataset information related to training and evaluation stages.
It supports extracting paths from different sources including local filesystem, remote locations, Hugging Face datasets,
Delta tables, and UC volume paths. The function aggregates unique dataset identifiers and their types from the configuration.
Expand All @@ -203,7 +205,8 @@ def parse_source_dataset(cfg: DictConfig):
data_paths = set()

for data_split in ['train', 'eval']:
split = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('split', None)
split = cfg.get(f'{data_split}_loader', {}).get('dataset',
{}).get('split', None)
source_dataset_path = cfg.get(f'source_dataset_{data_split}', {})

# check for Delta table
Expand All @@ -213,28 +216,36 @@ def parse_source_dataset(cfg: DictConfig):
elif source_dataset_path and source_dataset_path.startswith('/Volumes'):
data_paths.add(('uc_volume', source_dataset_path, data_split))
# check for HF path
elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name'):
hf_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('hf_name')
elif cfg.get(f'{data_split}_loader', {}).get('dataset',
{}).get('hf_name'):
hf_path = cfg.get(f'{data_split}_loader', {}).get('dataset',
{}).get('hf_name')
backend, _, _ = parse_uri(hf_path)
if backend:
hf_path = f'{hf_path.rstrip("/")}/{split}/' if split else hf_path
data_paths.add((backend, hf_path, data_split))
else:
data_paths.add(('hf', hf_path, data_split))
# check for remote path
elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('remote', None):
remote_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('remote', None)
elif cfg.get(f'{data_split}_loader', {}).get('dataset',
{}).get('remote', None):
remote_path = cfg.get(f'{data_split}_loader',
{}).get('dataset', {}).get('remote', None)
backend, _, _ = parse_uri(remote_path)
remote_path = f'{remote_path.rstrip("/")}/{split}/' if split else remote_path
data_paths.add((backend, remote_path, data_split))
# check for local path
elif cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('local', None):
local_path = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('local', None)
split = cfg.get(f'{data_split}_loader', {}).get('dataset', {}).get('split', None)
elif cfg.get(f'{data_split}_loader', {}).get('dataset',
{}).get('local', None):
local_path = cfg.get(f'{data_split}_loader',
{}).get('dataset', {}).get('local', None)
split = cfg.get(f'{data_split}_loader',
{}).get('dataset', {}).get('split', None)
data_paths.add(('local', local_path, data_split))

return data_paths


def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset:
"""
Extracts dataset information from the provided configuration and translates it into
Expand All @@ -261,7 +272,7 @@ def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset:

for dataset_type, path, split in data_paths:
source_class = dataset_source_mapping.get(dataset_type)

if source_class:
if dataset_type == 'delta_table':
source = source_class(delta_table_name=path)
Expand All @@ -270,7 +281,10 @@ def log_dataset_uri(cfg: DictConfig) -> mlflow.data.meta_dataset.MetaDataset:
else:
source = source_class(url=path)
else:
log.info(f'{dataset_type} unknown, defaulting to http dataset source')
source =mlflow.data.http_dataset_source.HTTPDatasetSource(uri=path)
log.info(
f'{dataset_type} unknown, defaulting to http dataset source')
source = mlflow.data.http_dataset_source.HTTPDatasetSource(uri=path)

mlflow.log_input(mlflow.data.meta_dataset.MetaDataset(source, name=f'{dataset_type} | {split}'))
mlflow.log_input(
mlflow.data.meta_dataset.MetaDataset(
source, name=f'{dataset_type} | {split}'))

0 comments on commit 996fb01

Please sign in to comment.