Skip to content

Commit

Permalink
Updated PET tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abmazitov committed Sep 24, 2024
1 parent 8a51cb2 commit d7ef992
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def test_batch_dicts_compatibility(cutoff):
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT,
ARCHITECTURAL_HYPERS.N_TARGETS > 1,
ARCHITECTURAL_HYPERS.TARGET_INDEX_KEY
)[0]
ref_batch_dict = {
"x": batch.x,
Expand Down Expand Up @@ -146,6 +148,8 @@ def test_predictions_compatibility(cutoff):
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT,
ARCHITECTURAL_HYPERS.N_TARGETS > 1,
ARCHITECTURAL_HYPERS.TARGET_INDEX_KEY
)[0]

batch_dict = {
Expand All @@ -161,7 +165,7 @@ def test_predictions_compatibility(cutoff):

pet = model.module.pet

pet_prediction = pet.forward(batch_dict)
pet_prediction = pet.forward(batch_dict)['prediction']

torch.testing.assert_close(
mtm_pet_prediction, pet_prediction.sum(dim=0, keepdim=True)
Expand Down

0 comments on commit d7ef992

Please sign in to comment.