Skip to content

Commit

Permalink
first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Apr 19, 2024
1 parent 0d66d1d commit d286e15
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,19 @@ def log_config(cfg: DictConfig) -> None:
raise ImportError('MLflow is required but not installed.')
if mlflow.active_run():
mlflow.log_params(params=om.to_container(cfg, resolve=True))
log_dataset_uri(cfg)
_log_dataset_uri(cfg)


def parse_source_dataset(cfg: DictConfig):
"""Parse a run config for dataset information."""
def _parse_source_dataset(cfg: DictConfig):
"""Parse a run config for dataset information.
Given a config dictionary, parse through it to determine what the data source
should be categorized as. Possible data sources are Delta Tables, UC Volumes,
HuggingFace paths, remote storage, or local storage.
Args:
cfg (DictConfig): A config dictionary of a run
"""
data_paths = set()

for data_split in ['train', 'eval']:
Expand Down Expand Up @@ -236,13 +244,17 @@ def parse_source_dataset(cfg: DictConfig):
return data_paths


def log_dataset_uri(cfg: DictConfig):
"""Logs dataset tracking information to MLflow."""
def _log_dataset_uri(cfg: DictConfig):
"""Logs dataset tracking information to MLflow.
Args:
cfg (DictConfig): A config dictionary of a run
"""
if mlflow is None:
log.warning('MLflow is not installed. Skipping dataset logging.')
return None
raise ImportError('MLflow is not installed. Skipping dataset logging.')

# Figure out which data source to use
data_paths = parse_source_dataset(cfg)
data_paths = _parse_source_dataset(cfg)

dataset_source_mapping = {
's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
Expand All @@ -254,6 +266,7 @@ def log_dataset_uri(cfg: DictConfig):
'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
}

# Map each data source type to its respective MLflow DataSource class.
for dataset_type, path, split in data_paths:
source_class = dataset_source_mapping.get(dataset_type)

Expand Down

0 comments on commit d286e15

Please sign in to comment.