diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index f173cc6c6b..4893b681a6 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -4,9 +4,9 @@ name: ONNX Runtime / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -22,62 +22,34 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - - name: Free disk space - if: matrix.os == 'ubuntu-20.04' - run: | - df -h - sudo apt-get update - sudo apt-get purge -y '^apache.*' - sudo apt-get purge -y '^imagemagick.*' - sudo apt-get purge -y '^dotnet.*' - sudo apt-get purge -y '^aspnetcore.*' - sudo apt-get purge -y 'php.*' - sudo apt-get purge -y '^temurin.*' - sudo apt-get purge -y '^mysql.*' - sudo apt-get purge -y '^java.*' - sudo apt-get purge -y '^openjdk.*' - sudo apt-get purge -y microsoft-edge-stable google-cloud-cli azure-cli google-chrome-stable firefox powershell mono-devel - df -h - sudo apt-get autoremove -y >/dev/null 2>&1 - sudo apt-get clean - df -h - echo "https://github.com/actions/virtual-environments/issues/709" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - df -h - echo "remove big /usr/local" - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf /usr/local/lib/android >/dev/null 2>&1 - df -h - echo "remove /usr/share leftovers" - sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1 - sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1 - sudo rm -rf /usr/share/swift > /dev/null 2>&1 - df -h - echo "remove other leftovers" - sudo rm -rf /var/lib/mysql > /dev/null 2>&1 - sudo rm -rf /home/runner/.dotnet > /dev/null 2>&1 - sudo rm -rf /home/runneradmin/.dotnet > /dev/null 2>&1 - sudo rm -rf /etc/skel/.dotnet > /dev/null 2>&1 - sudo rm -rf /usr/local/.ghcup > /dev/null 2>&1 - sudo rm -rf /usr/local/aws-cli > /dev/null 2>&1 - sudo rm -rf /usr/local/lib/node_modules > /dev/null 2>&1 - sudo rm -rf /usr/lib/heroku > /dev/null 2>&1 - sudo rm -rf /usr/local/share/chromium > /dev/null 2>&1 - df -h - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - run: | - pip install .[tests,onnxruntime] - - - name: Test with pytest - working-directory: tests - run: | - pytest -n auto -m "not run_in_series" --durations=0 -vs onnxruntime - pytest -m "run_in_series" --durations=0 onnxruntime + - name: Free Disk Space (Ubuntu) + if: matrix.os == 'ubuntu-20.04' + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + swap-storage: false + large-packages: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install .[tests,onnxruntime] + + - name: Test with pytest (in series) + working-directory: tests + run: | + pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s + + - name: Test with pytest (in parallel) + working-directory: tests + run: | + pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 455236126b..2d9be2d757 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -509,8 +509,6 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir - # Since v1.7.0 decoder with past models have fixed sequence length of 1 - # To keep these models compatible we set this dimension to dynamic onnx_model = onnx.load(str(model_cache_path), load_external_data=False) model_uses_external_data = check_model_uses_external_data(onnx_model) @@ -521,24 +519,47 @@ def _from_pretrained( node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input } + output_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output + } + + override_dims = False + + # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # To keep these models compatible we set this dimension to dynamic if input_dims["input_ids"][1] == 1: input_dims["input_ids"][1] = "sequence_length" - output_dims = { - node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] - for node in onnx_model.graph.output - } output_dims["logits"][1] = "sequence_length" - onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) + override_dims = True + # Since https://github.com/huggingface/optimum/pull/871/ + # changed axis notation/naming during export, we need to update the dims + for dim in input_dims.keys(): + if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": + input_dims[dim][2] = "past_sequence_length" + override_dims = True + + if override_dims: + # this is kinda dangerous, warning the user is the least we can do + logger.warning( + "The ONNX model was probably exported with an older version of optimum. " + "We are updating the input/output dimensions and overwriting the model file " + "with new dimensions. This is necessary for the model to work correctly with " + "the current version of optimum. If you encounter any issues, please re-export " + "the model with the latest version of optimum for optimal performance." + ) + onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) onnx.save( onnx_model, str(model_cache_path), save_as_external_data=model_uses_external_data, - all_tensors_to_one_file=True, location=model_cache_path.name + "_data", - size_threshold=0, + all_tensors_to_one_file=True, convert_attribute=True, + size_threshold=0, ) + del onnx_model model = ORTModel.load_model( diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index d56e301c3c..d93a7a3132 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -356,62 +356,45 @@ def quantize( ) quantizer_factory = QDQQuantizer if use_qdq else ONNXQuantizer + # TODO: maybe this logic can be moved to a method in the configuration class (get_ort_quantizer_kwargs()) + # that returns the dictionary of arguments to pass to the quantizer factory depending on the ort version + quantizer_kwargs = { + "model": onnx_model, + "static": quantization_config.is_static, + "per_channel": quantization_config.per_channel, + "mode": quantization_config.mode, + "weight_qType": quantization_config.weights_dtype, + "input_qType": quantization_config.activations_dtype, + "tensors_range": calibration_tensors_range, + "reduce_range": quantization_config.reduce_range, + "nodes_to_quantize": quantization_config.nodes_to_quantize, + "nodes_to_exclude": quantization_config.nodes_to_exclude, + "op_types_to_quantize": [ + operator.value if isinstance(operator, ORTQuantizableOperator) else operator + for operator in quantization_config.operators_to_quantize + ], + "extra_options": { + "WeightSymmetric": quantization_config.weights_symmetric, + "ActivationSymmetric": quantization_config.activations_symmetric, + "EnableSubgraph": has_subgraphs, + "ForceSymmetric": quantization_config.activations_symmetric and quantization_config.weights_symmetric, + "AddQDQPairToWeight": quantization_config.qdq_add_pair_to_weight, + "DedicatedQDQPair": quantization_config.qdq_dedicated_pair, + "QDQOpTypePerChannelSupportToAxis": quantization_config.qdq_op_type_per_channel_support_to_axis, + }, + } + + if use_qdq: + quantizer_kwargs.pop("mode") + if parse(ort_version) >= Version("1.18.0"): + # The argument `static` has been removed from the qdq quantizer factory in ORT 1.18 + quantizer_kwargs.pop("static") if parse(ort_version) >= Version("1.13.0"): - # The argument `input_qType` has been changed into `activation_qType` from ORT 1.13 - quantizer = quantizer_factory( - model=onnx_model, - static=quantization_config.is_static, - per_channel=quantization_config.per_channel, - mode=quantization_config.mode, - weight_qType=quantization_config.weights_dtype, - activation_qType=quantization_config.activations_dtype, - tensors_range=calibration_tensors_range, - reduce_range=quantization_config.reduce_range, - nodes_to_quantize=quantization_config.nodes_to_quantize, - nodes_to_exclude=quantization_config.nodes_to_exclude, - op_types_to_quantize=[ - operator.value if isinstance(operator, ORTQuantizableOperator) else operator - for operator in quantization_config.operators_to_quantize - ], - extra_options={ - "WeightSymmetric": quantization_config.weights_symmetric, - "ActivationSymmetric": quantization_config.activations_symmetric, - "EnableSubgraph": has_subgraphs, - "ForceSymmetric": quantization_config.activations_symmetric - and quantization_config.weights_symmetric, - "AddQDQPairToWeight": quantization_config.qdq_add_pair_to_weight, - "DedicatedQDQPair": quantization_config.qdq_dedicated_pair, - "QDQOpTypePerChannelSupportToAxis": quantization_config.qdq_op_type_per_channel_support_to_axis, - }, - ) - else: - quantizer = quantizer_factory( - model=onnx_model, - static=quantization_config.is_static, - per_channel=quantization_config.per_channel, - mode=quantization_config.mode, - weight_qType=quantization_config.weights_dtype, - input_qType=quantization_config.activations_dtype, - tensors_range=calibration_tensors_range, - reduce_range=quantization_config.reduce_range, - nodes_to_quantize=quantization_config.nodes_to_quantize, - nodes_to_exclude=quantization_config.nodes_to_exclude, - op_types_to_quantize=[ - operator.value if isinstance(operator, ORTQuantizableOperator) else operator - for operator in quantization_config.operators_to_quantize - ], - extra_options={ - "WeightSymmetric": quantization_config.weights_symmetric, - "ActivationSymmetric": quantization_config.activations_symmetric, - "EnableSubgraph": False, - "ForceSymmetric": quantization_config.activations_symmetric - and quantization_config.weights_symmetric, - "AddQDQPairToWeight": quantization_config.qdq_add_pair_to_weight, - "DedicatedQDQPair": quantization_config.qdq_dedicated_pair, - "QDQOpTypePerChannelSupportToAxis": quantization_config.qdq_op_type_per_channel_support_to_axis, - }, - ) + # The argument `input_qType` has been changed into `activation_qType` in ORT 1.13 + quantizer_kwargs["activation_qType"] = quantizer_kwargs.pop("input_qType") + + quantizer = quantizer_factory(**quantizer_kwargs) LOGGER.info("Quantizing model...") quantizer.quantize_model() diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 182e64beb9..3fe2c5e14d 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2274,21 +2274,25 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SPEEDUP_CACHE = 1.1 @parameterized.expand([(False,), (True,)]) + @pytest.mark.run_in_series def test_inference_old_onnx_model(self, use_cache): - model_id = "optimum/gpt2" + tokenizer = get_preprocessor("gpt2") model = AutoModelForCausalLM.from_pretrained("gpt2") - tokenizer = get_preprocessor(model_id) - text = "This is a sample output" - tokens = tokenizer(text, return_tensors="pt") - onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache) + onnx_model = ORTModelForCausalLM.from_pretrained("optimum/gpt2", use_cache=use_cache, use_io_binding=use_cache) self.assertEqual(onnx_model.use_cache, use_cache) self.assertEqual(onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME) - outputs_onnx = onnx_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 + + text = "The capital of France is" + tokens = tokenizer(text, return_tensors="pt") + + onnx_outputs = onnx_model.generate( + **tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10 ) - outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) - self.assertTrue(torch.allclose(outputs_onnx, outputs)) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10) + onnx_text_outputs = tokenizer.decode(onnx_outputs[0], skip_special_tokens=True) + text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True) + self.assertEqual(onnx_text_outputs, text_outputs) def test_load_model_from_hub_onnx(self): model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-without-merge") @@ -3596,6 +3600,7 @@ def _get_onnx_model_dir(self, model_id, model_arch, test_name): return onnx_model_dir + @pytest.mark.run_in_series def test_inference_old_onnx_model(self): model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small") diff --git a/tests/onnxruntime/test_stable_diffusion_pipeline.py b/tests/onnxruntime/test_stable_diffusion_pipeline.py index 0e56b22f71..44cd22ffec 100644 --- a/tests/onnxruntime/test_stable_diffusion_pipeline.py +++ b/tests/onnxruntime/test_stable_diffusion_pipeline.py @@ -227,20 +227,18 @@ def test_compare_diffusers_pipeline(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) height, width = 128, 128 - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs(height=height, width=width) inputs["prompt"] = "A painting of a squirrel eating a burger" inputs["image"] = floats_tensor((1, 3, height, width), rng=random.Random(SEED)) - output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] - # https://github.com/huggingface/diffusers/blob/v0.17.1/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py#L71 - expected_slice = np.array([0.69643, 0.58484, 0.50314, 0.58760, 0.55368, 0.59643, 0.51529, 0.41217, 0.49087]) - self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + ort_output = ort_pipeline(**inputs, generator=np.random.RandomState(SEED)).images + + diffusers_onnx_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_onnx_output = diffusers_onnx_pipeline(**inputs, generator=np.random.RandomState(SEED)).images - # Verify it can be loaded with ORT diffusers pipeline - diffusers_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - diffusers_output = diffusers_pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] - self.assertTrue(np.allclose(output, diffusers_output, atol=1e-2)) + self.assertTrue(np.allclose(ort_output, diffusers_onnx_output, atol=1e-1)) def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): inputs = _generate_inputs(batch_size=batch_size) @@ -418,6 +416,7 @@ def test_compare_diffusers_pipeline(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.ORTMODEL_CLASS.auto_model_class.from_pretrained(MODEL_NAMES[model_arch]) height, width = 64, 64 latents_shape = ( 1, @@ -425,22 +424,18 @@ def test_compare_diffusers_pipeline(self, model_arch: str): height // ort_pipeline.vae_scale_factor, width // ort_pipeline.vae_scale_factor, ) - latents = np.random.randn(*latents_shape).astype(np.float32) inputs = self.generate_inputs(height=height, width=width) - inputs["image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - inputs["mask_image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ).resize((width, height)) + np_latents = np.random.rand(*latents_shape).astype(np.float32) + torch_latents = torch.from_numpy(np_latents) + + ort_outputs = ort_pipeline(**inputs, latents=np_latents).images + self.assertEqual(ort_outputs.shape, (1, height, width, 3)) + + diffusers_outputs = diffusers_pipeline(**inputs, latents=torch_latents).images + self.assertEqual(diffusers_outputs.shape, (1, height, width, 3)) - outputs = ort_pipeline(**inputs, latents=latents).images - self.assertEqual(outputs.shape, (1, height, width, 3)) - expected_slice = np.array([0.5442, 0.3002, 0.5665, 0.6485, 0.4421, 0.6441, 0.5778, 0.5076, 0.5612]) - self.assertTrue(np.allclose(outputs[0, -3:, -3:, -1].flatten(), expected_slice, atol=1e-4)) + self.assertTrue(np.allclose(ort_outputs, diffusers_outputs, atol=1e-4)) def generate_inputs(self, height=128, width=128, batch_size=1): inputs = super(ORTStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width)