Skip to content

Commit

Permalink
fix failing model loading tests
Browse files Browse the repository at this point in the history
turns out that the outputs were actually equal to within tolerances,
but argsort for large enough vocabularies fails (since weight recovery is not perfect and introduces some error)
  • Loading branch information
mivanit committed Aug 20, 2024
1 parent 9e7b888 commit 3a589b1
Showing 1 changed file with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@
(
"raster",
MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=10
tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=5
),
),
(
"uniform",
MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=10
tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=5
),
),
(
"indexed",
MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=10
tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=5
),
),
("modular", MazeTokenizerModular()), # only checking default for now
Expand Down Expand Up @@ -106,7 +106,12 @@ def test_model_save_fold_ln(cfg_model: tuple[ConfigHolder, ZanjHookedTransformer
zanj.save(model, fname)
model_load = zanj.read(fname)

assert_model_output_equality(model, model_load)
vocab_size: int = len(model.zanj_model_config.tokenizer)
assert_model_output_equality(
model,
model_load,
check_argsort_equality=(vocab_size > 2048),
)


@pytest.mark.parametrize("cfg_model", MODELS, ids=lambda x: x[0].name)
Expand All @@ -131,4 +136,9 @@ def test_model_save_refactored_attn_matrices(
zanj.save(model, fname)
model_load = zanj.read(fname)

assert_model_output_equality(model, model_load)
vocab_size: int = len(model.zanj_model_config.tokenizer)
assert_model_output_equality(
model,
model_load,
check_argsort_equality=(vocab_size > 2048),
)

0 comments on commit 3a589b1

Please sign in to comment.