From 80acfb39faf10393e68bfe4fb17f24d24d82882c Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Tue, 23 Apr 2024 14:47:00 +0000 Subject: [PATCH] just because omega starts with OMMMM does not mean it's zen --- llmfoundry/data/text_data.py | 3 +-- tests/data/test_dataloader.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 61daac1165..47d4709eee 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -14,7 +14,6 @@ import transformers from composer.core.data_spec import DataSpec from composer.core.types import Batch -from omegaconf import OmegaConf as om from streaming import Stream, StreamingDataset from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase @@ -274,6 +273,7 @@ def build_text_dataloader( persistent_workers: bool = True, timeout: int = 0, ) -> DataSpec: + dataset_cfg = dataset # get kwargs @@ -450,7 +450,6 @@ def get_num_samples_in_batch(batch: Batch) -> int: 'drop_last': False, 'num_workers': 4, } - cfg = om.create(cfg) device_batch_size = 2 tokenizer_name = args.tokenizer diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 1b141af6b2..e584c8c11e 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -1119,7 +1119,7 @@ def test_token_counting_func_dataloader_setting( device_batch_size=batch_size, **cfg) elif dataloader_type == 'text': - cfg = DictConfig({ + cfg = { 'name': 'text', 'dataset': { 'local': 'dummy-path', @@ -1130,7 +1130,7 @@ def test_token_counting_func_dataloader_setting( 'shuffle_seed': 0, }, **common_args - }) + } ds_mock = MagicMock() ds_mock.tokenizer = gptt monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', @@ -1142,8 +1142,6 @@ def test_token_counting_func_dataloader_setting( else: raise NotImplementedError() - cfg = om.create(cfg) - batch_collated = dl.dataloader.collate_fn(batch_tokenized) # type: ignore actual_token_count = dl.get_num_tokens_in_batch(batch_collated)