Skip to content

Commit

Permalink
more quality
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Apr 19, 2024
1 parent f3d8348 commit c44fafa
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/utils/test_mlflow_logging.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock, patch
from unittest.mock import patch

import pytest
from typing import Any
from omegaconf import OmegaConf

from llmfoundry.utils.config_utils import log_dataset_uri, parse_source_dataset

mlflow = pytest.importorskip('mlflow')


def create_config(**kwargs):
def create_config(**kwargs: Any):
"""Helper function to create OmegaConf configurations."""
return OmegaConf.create(kwargs)

Expand Down Expand Up @@ -72,7 +73,7 @@ def test_parse_source_dataset_local():


@pytest.mark.usefixtures('mock_mlflow_classes')
def test_log_dataset_uri_all_sources(mock_mlflow_classes):
def test_log_dataset_uri_all_sources():
cfg = create_config(
train_loader={'dataset': {
'hf_name': 'huggingface/train_dataset'
Expand All @@ -83,7 +84,7 @@ def test_log_dataset_uri_all_sources(mock_mlflow_classes):
source_dataset_train='db.schema.train_table',
source_dataset_eval='/Volumes/eval_data')

with patch('mlflow.data.meta_dataset.MetaDataset') as mock_meta:
with patch('mlflow.data.meta_dataset.MetaDataset'):
with patch('mlflow.log_input') as mock_log_input:
log_dataset_uri(cfg)
assert mock_log_input.call_count == 2
Expand Down

0 comments on commit c44fafa

Please sign in to comment.