Skip to content

Commit

Permalink
Improve filepath handling in unit tests (#1737)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Sep 24, 2024
1 parent 6fc1f06 commit 9fe28a5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 22 deletions.
2 changes: 2 additions & 0 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@


def init_out_dir(out_dir: Path) -> Path:
    """Normalize the training output directory.

    Coerces non-``Path`` inputs (e.g. plain strings) to ``Path``. When the
    ``LIGHTNING_ARTIFACTS_DIR`` environment variable is set and *out_dir* is
    relative, the variable's value is prepended; absolute paths are returned
    unchanged.

    Args:
        out_dir: Desired output directory, as a ``Path`` or string-like value.

    Returns:
        The resolved output directory as a ``Path``.
    """
    path = out_dir if isinstance(out_dir, Path) else Path(out_dir)
    artifacts_dir = os.environ.get("LIGHTNING_ARTIFACTS_DIR")
    if artifacts_dir is not None and not path.is_absolute():
        return Path(artifacts_dir) / path
    return path
Expand Down
32 changes: 19 additions & 13 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def test_chat_with_quantized_model():
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model"])
@pytest.mark.timeout(300)
def test_finetune_model():
def test_finetune_model(tmp_path):

OUT_DIR = Path("out") / "lora"
DATASET_PATH = Path("custom_finetuning_dataset.json")
OUT_DIR = tmp_path / "out" / "lora"
DATASET_PATH = tmp_path / "custom_finetuning_dataset.json"
CHECKPOINT_DIR = "checkpoints" / REPO_ID

download_command = ["curl", "-L", "https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json", "-o", str(DATASET_PATH)]
Expand All @@ -105,8 +105,10 @@ def test_finetune_model():
]
run_command(finetune_command)

assert (OUT_DIR/"final").exists(), "Finetuning output directory was not created"
assert (OUT_DIR/"final"/"lit_model.pth").exists(), "Model file was not created"
generated_out_dir = OUT_DIR/"final"
assert generated_out_dir.exists(), f"Finetuning output directory ({generated_out_dir}) was not created"
model_file = OUT_DIR/"final"/"lit_model.pth"
assert model_file.exists(), f"Model file ({model_file}) was not created"


@pytest.mark.skipif(
Expand All @@ -116,8 +118,8 @@ def test_finetune_model():
)
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_pretrain_model():
OUT_DIR = Path("out") / "custom_pretrained"
def test_pretrain_model(tmp_path):
OUT_DIR = tmp_path / "out" / "custom_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"pythia-14m",
Expand All @@ -131,8 +133,10 @@ def test_pretrain_model():
output = run_command(pretrain_command)

assert "Warning: Preprocessed training data found" not in output
assert (OUT_DIR / "final").exists(), "Pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"
out_dir_path = OUT_DIR / "final"
assert out_dir_path.exists(), f"Pretraining output directory ({out_dir_path}) was not created"
out_model_path = OUT_DIR / "final" / "lit_model.pth"
assert out_model_path.exists(), f"Model file ({out_model_path}) was not created"

# Test that warning is displayed when running it a second time
output = run_command(pretrain_command)
Expand All @@ -146,8 +150,8 @@ def test_pretrain_model():
)
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_continue_pretrain_model():
OUT_DIR = Path("out") / "custom_continue_pretrained"
def test_continue_pretrain_model(tmp_path):
OUT_DIR = tmp_path / "out" / "custom_continue_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"pythia-14m",
Expand All @@ -161,8 +165,10 @@ def test_continue_pretrain_model():
]
run_command(pretrain_command)

assert (OUT_DIR / "final").exists(), "Continued pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"
generated_out_dir = OUT_DIR/"final"
assert generated_out_dir.exists(), f"Continued pretraining directory ({generated_out_dir}) was not created"
model_file = OUT_DIR/"final"/"lit_model.pth"
assert model_file.exists(), f"Model file ({model_file}) was not created"


@pytest.mark.dependency(depends=["test_download_model"])
Expand Down
26 changes: 17 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,23 @@ def test_choose_logger(tmp_path):
choose_logger("foo", out_dir=tmp_path, name="foo")


def test_init_out_dir(tmp_path):
    """init_out_dir leaves paths untouched unless LIGHTNING_ARTIFACTS_DIR is
    set, in which case relative paths get prefixed while absolute paths pass
    through unchanged."""
    relative_path = Path("./out")
    absolute_path = tmp_path / "out"

    # Remove any ambient LIGHTNING_ARTIFACTS_DIR so the "no prefix" branch is
    # actually exercised regardless of the host environment; the original
    # assertions failed when the variable happened to be set on the machine.
    clean_env = {k: v for k, v in os.environ.items() if k != "LIGHTNING_ARTIFACTS_DIR"}
    with mock.patch.dict(os.environ, clean_env, clear=True):
        assert init_out_dir(relative_path) == relative_path
        assert init_out_dir(absolute_path) == absolute_path

    with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
        assert init_out_dir(relative_path) == Path("prefix") / relative_path
        assert init_out_dir(absolute_path) == absolute_path
@pytest.mark.parametrize("path_type, input_path, expected", [
    ("relative", "some/relative/path", "some/relative/path"),
    ("absolute", "/usr/absolute/path", "/usr/absolute/path"),
    ("env_relative", "some/relative/path", "prefix/some/relative/path"),
    ("env_absolute", "/usr/absolute/path", "/usr/absolute/path")
])
def test_init_out_dir(path_type, input_path, expected):
    """init_out_dir prefixes relative paths with LIGHTNING_ARTIFACTS_DIR when
    the variable is set and passes absolute paths (and all paths when it is
    unset) through unchanged.

    Both branches patch os.environ so the expectation is fixed at
    parametrization time; the original test inspected the ambient environment
    at assert time and recomputed the expected prefix, duplicating the logic
    under test and making results machine-dependent.
    """
    if path_type.startswith("env_"):
        with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
            result = init_out_dir(input_path)
    else:
        # Guarantee the variable is absent for the non-env cases.
        clean_env = {k: v for k, v in os.environ.items() if k != "LIGHTNING_ARTIFACTS_DIR"}
        with mock.patch.dict(os.environ, clean_env, clear=True):
            result = init_out_dir(input_path)
    assert result == Path(expected), f"Failed for {path_type} with input {input_path} (result {result})"


def test_find_resume_path(tmp_path):
Expand Down

0 comments on commit 9fe28a5

Please sign in to comment.