enable determinism
IlyasMoutawwakil committed May 28, 2024
1 parent 0b9409d · commit 251272a
Showing 1 changed file with 7 additions and 6 deletions.
tests/gptq/test_quantization.py

@@ -18,7 +18,7 @@

 import torch
 from parameterized import parameterized
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, enable_full_determinism, set_seed
 from transformers.testing_utils import slow

 from optimum.gptq import GPTQQuantizer, load_quantized_model
@@ -75,7 +75,9 @@ def setUpClass(cls):
         Setup quantized model
         """

-        set_seed(42)
+        enable_full_determinism(42)
+        set_seed(42, deterministic=True)
+
         cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
             cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization
         )
@@ -130,14 +132,13 @@ def check_inference_correctness(self, model):
         Given that we are operating on small numbers + the testing model is relatively small, we might not get
         the same output across GPUs. So we'll generate a few tokens (5-10) and check their output.
         """
-        set_seed(42)
+        enable_full_determinism(42)
+        set_seed(42, deterministic=True)

         input_ids = self.tokenizer(self.input_text, return_tensors="pt").input_ids.to(self.device_for_inference)
         output_ids = model.generate(input_ids, do_sample=False, min_new_tokens=10, max_new_tokens=10)
         output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

         # TODO: use pytest features to show what we're comparing
         # Check the exactness of the result
         print(output_text)
         self.assertIn(output_text, self.EXPECTED_OUTPUTS)

     def test_generate_quality(self):
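
As a rough sketch of what the two new calls toggle (an approximation based on the behavior transformers documents for these helpers, not code from this commit), the pair expands to something like the following:

import os
import random

import numpy as np
import torch

SEED = 42

# What set_seed(SEED) covers: seed every RNG the generation path touches.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# What enable_full_determinism(SEED) adds on top: force deterministic kernels.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # deterministic cuBLAS workspaces
torch.use_deterministic_algorithms(True)  # raise on ops without a deterministic impl
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In recent transformers releases, passing deterministic=True to set_seed additionally calls torch.use_deterministic_algorithms(True), so the explicit set_seed(42, deterministic=True) after enable_full_determinism re-seeds the RNGs without undoing any of the settings above.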
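A minimal self-check in the same spirit as the updated test (a sketch, not part of the commit; the tiny model ID and prompt are placeholder assumptions):

from transformers import AutoModelForCausalLM, AutoTokenizer, enable_full_determinism, set_seed

model_id = "hf-internal-testing/tiny-random-gpt2"  # assumption: any small causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def generate_once() -> str:
    # Re-seed and re-enable determinism before every generation,
    # mirroring what the test does in check_inference_correctness.
    enable_full_determinism(42)
    set_seed(42, deterministic=True)
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, do_sample=False, min_new_tokens=10, max_new_tokens=10)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Greedy decoding plus deterministic kernels should make the decoded text
# bit-identical across runs on the same hardware, which is why the test can
# assert membership in a small fixed set of expected outputs.
assert generate_once() == generate_once()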
