diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index f460da8344..905d99c79b 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -5,12 +5,14 @@
 import copy
 import logging
 import math
+import os
 import warnings
 from dataclasses import dataclass, fields
 from typing import (Any, Callable, Dict, List, Literal, Mapping, Optional, Set,
                     Tuple, TypeVar, Union)
 
-from composer.utils import dist
+import mlflow
+from composer.utils import dist, parse_uri
 from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
 from omegaconf import OmegaConf as om
 
@@ -440,10 +442,134 @@ def log_config(cfg: Dict[str, Any]) -> None:
         if wandb.run:
             wandb.config.update(cfg)
 
-    if 'mlflow' in cfg.get('loggers', {}):
-        try:
-            import mlflow
-        except ImportError as e:
-            raise e
-        if mlflow.active_run():
-            mlflow.log_params(params=cfg)
+    if 'mlflow' in cfg.get('loggers', {}) and mlflow.active_run():
+        mlflow.log_params(params=om.to_container(cfg, resolve=True))
+        _log_dataset_uri(cfg)
+
+
+def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
+    """Parse a run config for dataset information.
+
+    Given a config dictionary, determine how each data source should be
+    categorized. Possible data sources are Delta Tables, UC Volumes,
+    HuggingFace paths, remote storage, or local storage.
+
+    Args:
+        cfg (Dict[str, Any]): A config dictionary of a run
+
+    Returns:
+        List[Tuple[str, str, str]]: A list of tuples formatted as (data type, path, split)
+    """
+    data_paths = []
+
+    # Handle train loader if it exists
+    train_dataset: Dict = cfg.get('train_loader', {}).get('dataset', {})
+    train_split = train_dataset.get('split', None)
+    train_source_path = cfg.get('source_dataset_train', None)
+    _process_data_source(train_source_path, train_dataset, train_split, 'train',
+                         data_paths)
+
+    # Handle eval_loader, which might be a list or a single dictionary
+    eval_data_loaders = cfg.get('eval_loader', {})
+    if not isinstance(eval_data_loaders, list):
+        eval_data_loaders = [eval_data_loaders
+                            ]  # Normalize to list if it's a single dictionary
+
+    for eval_data_loader in eval_data_loaders:
+        assert isinstance(eval_data_loader, dict)  # pyright type check
+        eval_dataset: Dict = eval_data_loader.get('dataset', {})
+        eval_split = eval_dataset.get('split', None)
+        eval_source_path = cfg.get('source_dataset_eval', None)
+        _process_data_source(eval_source_path, eval_dataset, eval_split, 'eval',
+                             data_paths)
+
+    return data_paths
+
+
+def _process_data_source(source_dataset_path: Optional[str],
+                         dataset: Dict[str, str], cfg_split: Optional[str],
+                         true_split: str, data_paths: List[Tuple[str, str,
+                                                                 str]]):
+    """Add a data source by mutating data_paths.
+
+    Given various dataset attributes, attempt to determine what type of dataset
+    is being added, and parse the dataset accordingly.
+
+    Args:
+        source_dataset_path (Optional[str]): The source dataset in cfg metadata
+        dataset (Dict[str, str]): The dataset from cfg
+        cfg_split (Optional[str]): The split listed for the dataset in cfg
+        true_split (str): The split of the dataset to be added (i.e. train or eval)
+        data_paths (List[Tuple[str, str, str]]): A list of tuples formatted as (data type, path, split)
+    """
+    # Check for Delta table
+    if source_dataset_path and len(source_dataset_path.split('.')) == 3:
+        data_paths.append(('delta_table', source_dataset_path, true_split))
+    # Check for UC volume
+    elif source_dataset_path and source_dataset_path.startswith('dbfs:'):
+        data_paths.append(
+            ('uc_volume', source_dataset_path[len('dbfs:'):], true_split))
+    # Check for HF path
+    elif 'hf_name' in dataset:
+        hf_path = dataset['hf_name']
+        backend, _, _ = parse_uri(hf_path)
+        if backend:
+            hf_path = os.path.join(hf_path, cfg_split) if cfg_split else hf_path
+            data_paths.append((backend, hf_path, true_split))
+        elif os.path.exists(hf_path):
+            data_paths.append(('local', hf_path, true_split))
+        else:
+            data_paths.append(('hf', hf_path, true_split))
+    # Check for remote path
+    elif 'remote' in dataset:
+        remote_path = dataset['remote']
+        backend, _, _ = parse_uri(remote_path)
+        if backend:
+            remote_path = os.path.join(
+                remote_path, f'{cfg_split}/') if cfg_split else remote_path
+            data_paths.append((backend, remote_path, true_split))
+        else:
+            data_paths.append(('local', remote_path, true_split))
+    else:
+        log.warning('DataSource Not Found.')
+
+
+def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
+    """Logs dataset tracking information to MLflow.
+
+    Args:
+        cfg (Dict[str, Any]): A config dictionary of a run
+    """
+    # Figure out which data source to use
+    data_paths = _parse_source_dataset(cfg)
+
+    # Map data source types to their respective MLflow DataSource classes.
+    dataset_source_mapping = {
+        's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
+        'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
+        'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
+        'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
+        'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
+    }
+    for dataset_type, path, split in data_paths:
+
+        if dataset_type in dataset_source_mapping:
+            source_class = dataset_source_mapping[dataset_type]
+            if dataset_type == 'delta_table':
+                source = source_class(delta_table_name=path)
+            elif dataset_type == 'hf' or dataset_type == 'uc_volume':
+                source = source_class(path=path)
+            else:
+                source = source_class(url=path)
+        else:
+            log.info(
+                f'{dataset_type} unknown, defaulting to http dataset source')
+            source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)
+
+        mlflow.log_input(
+            mlflow.data.meta_dataset.MetaDataset(source, name=split))
diff --git a/tests/utils/test_mlflow_logging.py b/tests/utils/test_mlflow_logging.py
new file mode 100644
index 0000000000..b8dd0becdf
--- /dev/null
+++ b/tests/utils/test_mlflow_logging.py
@@ -0,0 +1,133 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+from omegaconf import OmegaConf
+
+from llmfoundry.utils.config_utils import (_log_dataset_uri,
+                                           _parse_source_dataset)
+
+mlflow = pytest.importorskip('mlflow')
+from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource
+
+
+def create_config(**kwargs: Any):
+    """Helper function to create OmegaConf configurations."""
+    return OmegaConf.create(kwargs)
+
+
+def test_parse_source_dataset_delta_table():
+    cfg = create_config(source_dataset_train='db.schema.train_table',
+                        source_dataset_eval='db.schema.eval_table')
+    expected = [('delta_table', 'db.schema.train_table', 'train'),
+                ('delta_table', 'db.schema.eval_table', 'eval')]
+    assert _parse_source_dataset(cfg) == expected
+
+
+def test_parse_source_dataset_uc_volume():
+    cfg = create_config(source_dataset_train='dbfs:/Volumes/train_data',
+                        source_dataset_eval='dbfs:/Volumes/eval_data')
+    expected = [('uc_volume', '/Volumes/train_data', 'train'),
+                ('uc_volume', '/Volumes/eval_data', 'eval')]
+    assert _parse_source_dataset(cfg) == expected
+
+
+def test_parse_source_dataset_hf():
+    cfg = create_config(
+        train_loader={'dataset': {
+            'hf_name': 'huggingface/train_dataset',
+        }},
+        eval_loader={'dataset': {
+            'hf_name': 'huggingface/eval_dataset',
+        }})
+    expected = [('hf', 'huggingface/train_dataset', 'train'),
+                ('hf', 'huggingface/eval_dataset', 'eval')]
+    assert _parse_source_dataset(cfg) == expected
+
+
+def test_parse_source_dataset_remote():
+    cfg = create_config(train_loader={
+        'dataset': {
+            'remote': 'https://remote/train_dataset',
+            'split': 'train'
+        }
+    },
+                        eval_loader={
+                            'dataset': {
+                                'remote': 'https://remote/eval_dataset',
+                                'split': 'eval'
+                            }
+                        })
+    expected = [('https', 'https://remote/train_dataset/train/', 'train'),
+                ('https', 'https://remote/eval_dataset/eval/', 'eval')]
+    assert _parse_source_dataset(cfg) == expected
+
+
+def test_log_dataset_uri():
+    cfg = create_config(
+        train_loader={'dataset': {
+            'hf_name': 'huggingface/train_dataset'
+        }},
+        eval_loader={'dataset': {
+            'hf_name': 'huggingface/eval_dataset'
+        }},
+        source_dataset_train='huggingface/train_dataset',
+        source_dataset_eval='huggingface/eval_dataset')
+
+    with patch('mlflow.log_input') as mock_log_input:
+        _log_dataset_uri(cfg)
+        assert mock_log_input.call_count == 2
+        meta_dataset_calls = [
+            args[0] for args, _ in mock_log_input.call_args_list
+        ]
+        assert all(
+            isinstance(call.source, HuggingFaceDatasetSource)
+            for call in meta_dataset_calls), 'Source types are incorrect'
+        # Verify the names
+        assert meta_dataset_calls[
+            0].name == 'train', f"Expected 'train', got {meta_dataset_calls[0].name}"
+        assert meta_dataset_calls[
+            1].name == 'eval', f"Expected 'eval', got {meta_dataset_calls[1].name}"
+
+
+def test_multiple_eval_datasets():
+    # Set up a configuration with multiple evaluation datasets
+    cfg = OmegaConf.create({
+        'train_loader': {
+            'dataset': {
+                'hf_name': 'huggingface/train_dataset',
+            },
+        },
+        'eval_loader': [{
+            'dataset': {
+                'hf_name': 'huggingface/eval_dataset1',
+            },
+        }, {
+            'dataset': {
+                'hf_name': 'huggingface/eval_dataset2',
+            },
+        }]
+    })
+
+    expected_data_paths = [('hf', 'huggingface/train_dataset', 'train'),
+                           ('hf', 'huggingface/eval_dataset1', 'eval'),
+                           ('hf', 'huggingface/eval_dataset2', 'eval')]
+
+    # Mock mlflow to avoid any actual logging calls
+    with patch('mlflow.data.meta_dataset.MetaDataset') as mock_meta_dataset:
+        mock_meta_dataset.side_effect = lambda source, name: MagicMock()
+        data_paths = _parse_source_dataset(cfg)
+        assert sorted(data_paths) == sorted(
+            expected_data_paths), 'Data paths did not match expected'
+
+
+@pytest.fixture
+def mock_mlflow_classes():
+    with patch('mlflow.data.http_dataset_source.HTTPDatasetSource') as http_source, \
+         patch('mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource') as hf_source, \
+         patch('mlflow.data.delta_dataset_source.DeltaDatasetSource') as delta_source, \
+         patch('mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource') as uc_source:
+        yield http_source, hf_source, delta_source, uc_source
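
A minimal usage sketch (not part of the patch) of how the new helper categorizes a run config; the Delta table name, S3 bucket, and splits below are hypothetical, and the expected output assumes the logic added in config_utils.py above.

from llmfoundry.utils.config_utils import _parse_source_dataset

# Hypothetical run config: a Delta table recorded as the train source and an
# S3-backed streaming dataset for eval.
cfg = {
    'source_dataset_train': 'catalog.schema.train_table',
    'train_loader': {'dataset': {'split': 'train'}},
    'eval_loader': {'dataset': {'remote': 's3://my-bucket/eval', 'split': 'val'}},
}

# Expected: [('delta_table', 'catalog.schema.train_table', 'train'),
#            ('s3', 's3://my-bucket/eval/val/', 'eval')]
print(_parse_source_dataset(cfg))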