Skip to content

Commit

Permalink
refactor(onnxruntime): add subpackage and move commands
Browse files Browse the repository at this point in the history
The onnxruntime commands are moved into a subpackage loader directory.
This subpackage directory is only loaded (and its commands added) when
the onnxruntime is available.
This avoids wrongly indicating that the onnxruntime commands are available
when the package is actually not installed.
  • Loading branch information
dacorvo committed Jun 6, 2024
1 parent 7e1c79f commit 40e4f8d
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 11 deletions.
1 change: 0 additions & 1 deletion optimum/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand
from .optimum_cli import optimum_cli_subcommand
3 changes: 1 addition & 2 deletions optimum/commands/optimum_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand
from .onnxruntime import ONNXRuntimeCommand


logger = logging.get_logger()

# The table below contains the optimum-cli root subcommands provided by the optimum package
OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand, ONNXRuntimeCommand]
OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand]

# The table below is dynamically populated when loading subpackages
_OPTIMUM_CLI_SUBCOMMANDS = []
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/subpackage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .commands import ONNXRuntimeCommand
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,3 @@
# limitations under the License.

from .base import ONNXRuntimeCommand
from .optimize import ONNXRuntimeOptimizeCommand
from .quantize import ONNXRuntimeQuantizeCommand
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.
"""optimum.onnxruntime command-line interface base classes."""

from .. import BaseOptimumCLICommand, CommandInfo
from optimum.commands import BaseOptimumCLICommand, CommandInfo, optimum_cli_subcommand

from .optimize import ONNXRuntimeOptimizeCommand
from .quantize import ONNXRuntimeQuantizeCommand


@optimum_cli_subcommand()
class ONNXRuntimeCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(
name="onnxruntime",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_onnxruntime_optimize(parser)

def run(self):
from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig
from ...onnxruntime.optimization import ORTOptimizer
from ...configuration import AutoOptimizationConfig, ORTConfig
from ...optimization import ORTOptimizer

if self.args.output == self.args.onnx_model:
raise ValueError("The output directory must be different than the directory hosting the ONNX model.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

from .. import BaseOptimumCLICommand
from optimum.commands import BaseOptimumCLICommand


if TYPE_CHECKING:
Expand Down Expand Up @@ -69,8 +69,8 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_onnxruntime_quantize(parser)

def run(self):
from ...onnxruntime.configuration import AutoQuantizationConfig, ORTConfig
from ...onnxruntime.quantization import ORTQuantizer
from ...configuration import AutoQuantizationConfig, ORTConfig
from ...quantization import ORTQuantizer

if self.args.output == self.args.onnx_model:
raise ValueError("The output directory must be different than the directory hosting the ONNX model.")
Expand Down
8 changes: 8 additions & 0 deletions optimum/subpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import logging
import sys

Expand All @@ -23,6 +24,8 @@
import importlib_metadata
from importlib.util import find_spec, module_from_spec

from .utils import is_onnxruntime_available


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,3 +74,8 @@ def load_subpackages():
"""
SUBPACKAGE_LOADER = "subpackage"
load_namespace_modules("optimum", SUBPACKAGE_LOADER)

# Load subpackages from internal modules not explicitly defined as namespace packages
loader_name = "." + SUBPACKAGE_LOADER
if is_onnxruntime_available():
importlib.import_module(loader_name, package="optimum.onnxruntime")

0 comments on commit 40e4f8d

Please sign in to comment.