Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Sep 14, 2023
1 parent 7d1b5c3 commit c06a40b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
19 changes: 9 additions & 10 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,15 @@ def parse_loss_from_log(log: str) -> List[float]:
return losses

@staticmethod
def check_that_loss_is_decreasing(losses: List[float], steps: int) -> bool:
def check_that_loss_is_decreasing(losses: List[float], steps: int) -> Tuple[bool, List[float], List[float]]:
mean_losses = []
num_mean_losses = len(losses) // steps
for i in range(num_mean_losses):
mean = sum(losses[i * steps : (i + 1) * steps]) / steps
mean_losses.append(mean)

return mean_losses == sorted(mean_losses, reverse=True)
expected_mean_losses = sorted(mean_losses, reverse=True)
return mean_losses == expected_mean_losses, mean_losses, expected_mean_losses

@classmethod
def _create_test(
Expand Down Expand Up @@ -366,15 +367,19 @@ def test(self):
disable_embedding_parallelization=disable_embedding_parallelization,
zero_1=zero_1,
output_dir=tmpdirname,
# TODO: enable precompilation once it's working with subprocess.
do_precompilation=True,
print_outputs=True,
)
assert returncode == 0

if self.CHECK_THAT_LOSS_IS_DECREASING:
losses = ExampleTestMeta.parse_loss_from_log(stdout)
assert ExampleTestMeta.check_that_loss_is_decreasing(losses, 20)
is_decreasing, mean_losses, expected_mean_losses = ExampleTestMeta.check_that_loss_is_decreasing(
losses, 50
)
self.assertTrue(
is_decreasing, f"Expected mean losses to be {expected_mean_losses} but got {mean_losses}"
)

if self.DO_EVAL:
with open(Path(tmpdirname) / "all_results.json") as fp:
Expand All @@ -388,12 +393,6 @@ def test(self):
else:
self.assertLessEqual(float(results[self.SCORE_NAME]), eval_score_threshold)

# train_loss_threshold = (
# self.TRAIN_LOSS_THRESHOLD if not RUN_TINY else self.TRAIN_LOSS_THRESHOLD_FOR_TINY
# )
# train_loss_threshold = ExampleTestMeta.process_class_attribute(train_loss_threshold, model_type)
# self.assertLessEqual(float(results["train_loss"]), train_loss_threshold)

return test


Expand Down
4 changes: 1 addition & 3 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def tearDownClass(cls) -> None:
@parameterized.expand(TO_TEST)
def test_run_example(self, task, model_name_or_path, sequence_length):
runner = ExampleRunner(model_name_or_path, task)
returncode, stdout = runner.run(
1, "bf16", 1, sequence_length=sequence_length, max_steps=10, save_steps=5
)
returncode, stdout = runner.run(1, "bf16", 1, sequence_length=sequence_length, max_steps=10, save_steps=5)
print(stdout)
if returncode != 0:
self.fail(f"ExampleRunner failed for task {task}.\nStandard output:\n{stdout}")

0 comments on commit c06a40b

Please sign in to comment.