Commit 063ab43: merged and resolved
milocress committed Apr 24, 2024
2 parents f2ed1d7 + 72da1d7
Showing 2 changed files with 267 additions and 8 deletions.
142 changes: 134 additions & 8 deletions llmfoundry/utils/config_utils.py
@@ -5,12 +5,14 @@
 import copy
 import logging
 import math
+import os
 import warnings
 from dataclasses import dataclass, fields
 from typing import (Any, Callable, Dict, List, Literal, Mapping, Optional, Set,
                     Tuple, TypeVar, Union)
 
-from composer.utils import dist
+import mlflow
+from composer.utils import dist, parse_uri
 from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
 from omegaconf import OmegaConf as om

@@ -440,10 +442,134 @@ def log_config(cfg: Dict[str, Any]) -> None:
         if wandb.run:
             wandb.config.update(cfg)
 
-    if 'mlflow' in cfg.get('loggers', {}):
-        try:
-            import mlflow
-        except ImportError as e:
-            raise e
-        if mlflow.active_run():
-            mlflow.log_params(params=cfg)
+    if 'mlflow' in cfg.get('loggers', {}) and mlflow.active_run():
+        mlflow.log_params(params=om.to_container(cfg, resolve=True))
+        _log_dataset_uri(cfg)
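For context, here is a minimal sketch (hypothetical values, not from a real run) of the config shape this change reads: the 'loggers' section must contain 'mlflow' and a run must be active for params to be logged, while the loader datasets and the optional source_dataset_train / source_dataset_eval keys feed the dataset-lineage logging added below.

# Illustrative config sketch (made-up names); only the keys the new code inspects are shown.
example_cfg = {
    'loggers': {'mlflow': {}},  # 'mlflow' must appear here for the branch above to fire
    'source_dataset_train': 'catalog.schema.train_table',  # three dot-parts -> logged as a Delta table
    'train_loader': {'dataset': {'hf_name': 'org/train_dataset', 'split': 'train'}},  # ignored for lineage when source_dataset_train is set
    'eval_loader': {'dataset': {'remote': 's3://bucket/eval_data', 'split': 'eval'}},  # remote backend -> logged with its URI scheme
}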


def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
    """Parse a run config for dataset information.

    Given a config dictionary, determine how each data source should be
    categorized. Possible data sources are Delta tables, UC Volumes,
    HuggingFace paths, remote storage, or local storage.

    Args:
        cfg (Dict[str, Any]): A config dictionary of a run

    Returns:
        List[Tuple[str, str, str]]: A list of tuples formatted as (data type, path, split)
    """
    data_paths = []

    # Handle train loader if it exists
    train_dataset: Dict = cfg.get('train_loader', {}).get('dataset', {})
    train_split = train_dataset.get('split', None)
    train_source_path = cfg.get('source_dataset_train', None)
    _process_data_source(train_source_path, train_dataset, train_split, 'train',
                         data_paths)

    # Handle eval_loader which might be a list or a single dictionary
    eval_data_loaders = cfg.get('eval_loader', {})
    if not isinstance(eval_data_loaders, list):
        # Normalize to a list if it's a single dictionary
        eval_data_loaders = [eval_data_loaders]

    for eval_data_loader in eval_data_loaders:
        assert isinstance(eval_data_loader, dict)  # pyright type check
        eval_dataset: Dict = eval_data_loader.get('dataset', {})
        eval_split = eval_dataset.get('split', None)
        eval_source_path = cfg.get('source_dataset_eval', None)
        _process_data_source(eval_source_path, eval_dataset, eval_split, 'eval',
                             data_paths)

    return data_paths
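As a rough illustration of the return value (paths are hypothetical), a config whose train loader sets hf_name='org/train_dataset' and whose eval loader sets remote='s3://bucket/eval_data' with split='eval' would yield:

# [('hf', 'org/train_dataset', 'train'),
#  ('s3', 's3://bucket/eval_data/eval/', 'eval')]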


def _process_data_source(source_dataset_path: Optional[str],
                         dataset: Dict[str, str], cfg_split: Optional[str],
                         true_split: str,
                         data_paths: List[Tuple[str, str, str]]):
    """Add a data source by mutating data_paths.

    Given various dataset attributes, attempt to determine what type of dataset
    is being added, and parse the dataset accordingly.

    Args:
        source_dataset_path (Optional[str]): The source dataset in cfg metadata
        dataset (Dict[str, str]): The dataset from cfg
        cfg_split (Optional[str]): The split listed for the dataset in cfg
        true_split (str): The split of the dataset to be added (i.e. train or eval)
        data_paths (List[Tuple[str, str, str]]): A list of tuples formatted as (data type, path, split)
    """
    # Check for Delta table
    if source_dataset_path and len(source_dataset_path.split('.')) == 3:
        data_paths.append(('delta_table', source_dataset_path, true_split))
    # Check for UC volume
    elif source_dataset_path and source_dataset_path.startswith('dbfs:'):
        data_paths.append(
            ('uc_volume', source_dataset_path[len('dbfs:'):], true_split))
    # Check for HF path
    elif 'hf_name' in dataset:
        hf_path = dataset['hf_name']
        backend, _, _ = parse_uri(hf_path)
        if backend:
            hf_path = os.path.join(hf_path, cfg_split) if cfg_split else hf_path
            data_paths.append((backend, hf_path, true_split))
        elif os.path.exists(hf_path):
            data_paths.append(('local', hf_path, true_split))
        else:
            data_paths.append(('hf', hf_path, true_split))
    # Check for remote path
    elif 'remote' in dataset:
        remote_path = dataset['remote']
        backend, _, _ = parse_uri(remote_path)
        if backend:
            remote_path = os.path.join(
                remote_path, f'{cfg_split}/') if cfg_split else remote_path
            data_paths.append((backend, remote_path, true_split))
        else:
            data_paths.append(('local', remote_path, true_split))
    else:
        log.warning('DataSource Not Found.')


def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
    """Logs dataset tracking information to MLflow.

    Args:
        cfg (Dict[str, Any]): A config dictionary of a run
    """
    # Figure out which data source to use
    data_paths = _parse_source_dataset(cfg)

    dataset_source_mapping = {
        's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
        'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
        'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
        'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
    }

    # Map data source types to their respective MLFlow DataSource.
    for dataset_type, path, split in data_paths:

        if dataset_type in dataset_source_mapping:
            source_class = dataset_source_mapping[dataset_type]
            if dataset_type == 'delta_table':
                source = source_class(delta_table_name=path)
            elif dataset_type == 'hf' or dataset_type == 'uc_volume':
                source = source_class(path=path)
            else:
                source = source_class(url=path)
        else:
            log.info(
                f'{dataset_type} unknown, defaulting to http dataset source')
            source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)

        mlflow.log_input(
            mlflow.data.meta_dataset.MetaDataset(source, name=split))
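A minimal usage sketch, assuming an active MLflow run and an 'mlflow' entry under 'loggers'; the config values are illustrative, and log_config is the existing entry point that now calls _log_dataset_uri:

# Usage sketch (hypothetical config values). Under an active MLflow run,
# log_config logs the resolved config as params and attaches one MetaDataset
# input per split via _log_dataset_uri.
import mlflow
from omegaconf import OmegaConf

from llmfoundry.utils.config_utils import log_config

cfg = OmegaConf.create({
    'loggers': {'mlflow': {}},
    'train_loader': {'dataset': {'hf_name': 'org/train_dataset', 'split': 'train'}},
    'eval_loader': {'dataset': {'hf_name': 'org/eval_dataset', 'split': 'eval'}},
})

with mlflow.start_run():
    log_config(cfg)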
133 changes: 133 additions & 0 deletions tests/utils/test_mlflow_logging.py
@@ -0,0 +1,133 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from omegaconf import OmegaConf

from llmfoundry.utils.config_utils import (_log_dataset_uri,
                                           _parse_source_dataset)

mlflow = pytest.importorskip('mlflow')
from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource


def create_config(**kwargs: Any):
    """Helper function to create OmegaConf configurations."""
    return OmegaConf.create(kwargs)


def test_parse_source_dataset_delta_table():
    cfg = create_config(source_dataset_train='db.schema.train_table',
                        source_dataset_eval='db.schema.eval_table')
    expected = [('delta_table', 'db.schema.train_table', 'train'),
                ('delta_table', 'db.schema.eval_table', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_uc_volume():
    cfg = create_config(source_dataset_train='dbfs:/Volumes/train_data',
                        source_dataset_eval='dbfs:/Volumes/eval_data')
    expected = [('uc_volume', '/Volumes/train_data', 'train'),
                ('uc_volume', '/Volumes/eval_data', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_hf():
    cfg = create_config(
        train_loader={'dataset': {
            'hf_name': 'huggingface/train_dataset',
        }},
        eval_loader={'dataset': {
            'hf_name': 'huggingface/eval_dataset',
        }})
    expected = [('hf', 'huggingface/train_dataset', 'train'),
                ('hf', 'huggingface/eval_dataset', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_remote():
    cfg = create_config(
        train_loader={
            'dataset': {
                'remote': 'https://remote/train_dataset',
                'split': 'train'
            }
        },
        eval_loader={
            'dataset': {
                'remote': 'https://remote/eval_dataset',
                'split': 'eval'
            }
        })
    expected = [('https', 'https://remote/train_dataset/train/', 'train'),
                ('https', 'https://remote/eval_dataset/eval/', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_log_dataset_uri():
    cfg = create_config(
        train_loader={'dataset': {
            'hf_name': 'huggingface/train_dataset'
        }},
        eval_loader={'dataset': {
            'hf_name': 'huggingface/eval_dataset'
        }},
        source_dataset_train='huggingface/train_dataset',
        source_dataset_eval='huggingface/eval_dataset')

    with patch('mlflow.log_input') as mock_log_input:
        _log_dataset_uri(cfg)
        assert mock_log_input.call_count == 2
        meta_dataset_calls = [
            args[0] for args, _ in mock_log_input.call_args_list
        ]
        assert all(
            isinstance(call.source, HuggingFaceDatasetSource)
            for call in meta_dataset_calls), 'Source types are incorrect'
        # Verify the names
        assert meta_dataset_calls[0].name == 'train', \
            f"Expected 'train', got {meta_dataset_calls[0].name}"
        assert meta_dataset_calls[1].name == 'eval', \
            f"Expected 'eval', got {meta_dataset_calls[1].name}"


def test_multiple_eval_datasets():
    # Setup a configuration with multiple evaluation datasets
    cfg = OmegaConf.create({
        'train_loader': {
            'dataset': {
                'hf_name': 'huggingface/train_dataset',
            },
        },
        'eval_loader': [{
            'dataset': {
                'hf_name': 'huggingface/eval_dataset1',
            },
        }, {
            'dataset': {
                'hf_name': 'huggingface/eval_dataset2',
            },
        }]
    })

    expected_data_paths = [('hf', 'huggingface/train_dataset', 'train'),
                           ('hf', 'huggingface/eval_dataset1', 'eval'),
                           ('hf', 'huggingface/eval_dataset2', 'eval')]

    # Mock mlflow to avoid any actual logging calls
    with patch('mlflow.data.meta_dataset.MetaDataset') as mock_meta_dataset:
        mock_meta_dataset.side_effect = lambda source, name: MagicMock()
        data_paths = _parse_source_dataset(cfg)
        assert sorted(data_paths) == sorted(
            expected_data_paths), 'Data paths did not match expected'


@pytest.fixture
def mock_mlflow_classes():
    with patch('mlflow.data.http_dataset_source.HTTPDatasetSource') as http_source, \
         patch('mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource') as hf_source, \
         patch('mlflow.data.delta_dataset_source.DeltaDatasetSource') as delta_source, \
         patch('mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource') as uc_source:
        yield http_source, hf_source, delta_source, uc_source
