Skip to content

Commit

Permalink
remove multiple output
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Jul 21, 2023
1 parent cfe6239 commit 3254d6e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/gptq/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@
from optimum.gptq import GPTQQuantizer, load_quantized_model
from optimum.gptq.data import get_dataset
from optimum.utils.testing_utils import require_accelerate, require_autogptq, require_torch_gpu
from transformers.testing_utils import slow


@slow
@require_autogptq
@require_torch_gpu
class GTPQTest(unittest.TestCase):
model_name = "bigscience/bloom-560m"

input_text = "Hello my name is"
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUT = "Hello my name is John and I am a professional photographer. I"

# this seems a little small considering that we are doing 4bit quant but we have a small model and we don't quantize the embeddings
EXPECTED_RELATIVE_DIFFERENCE = 1.664253062
Expand Down Expand Up @@ -93,7 +93,7 @@ def check_inference_correctness(self, model):
output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

# Get the generation
self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def test_generate_quality(self):
    """Verify that the quantized model still generates the expected text.

    Delegates to ``check_inference_correctness``, which runs a short
    generation and compares the decoded output against the expected string.
    """
    quantized = self.quantized_model
    self.check_inference_correctness(quantized)
Expand Down

0 comments on commit 3254d6e

Please sign in to comment.