From a564d3e87c0a170564c32919c36cb0bad06e3a5d Mon Sep 17 00:00:00 2001 From: fabian Date: Fri, 5 Jan 2024 17:20:30 +0100 Subject: [PATCH] Additionally return the model for the selected subcommand --- .../pydantic_argparse/argparse/parser.py | 6 +++--- .../pydantic_argparse/utils/nesting.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py index aa1dc53d..bb67a5e6 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py +++ b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py @@ -21,7 +21,7 @@ import argparse import sys -from typing import Dict, Generic, List, NoReturn, Optional, Type, cast +from typing import Dict, Generic, List, NoReturn, Optional, Tuple, Type from pydantic import BaseModel, ValidationError @@ -133,7 +133,7 @@ def has_submodels(self) -> bool: # noqa: D102 def parse_typed_args( self, args: Optional[List[str]] = None, - ) -> PydanticModelT: + ) -> Tuple[PydanticModelT, BaseModel]: """Parses command line arguments. If `args` are not supplied by the user, then they are automatically @@ -154,7 +154,7 @@ def parse_typed_args( try: nested_parser = _NestedArgumentParser(model=self.model, namespace=namespace) - return cast(PydanticModelT, nested_parser.validate()) + return nested_parser.validate() except ValidationError as exc: # Catch exceptions, and use the ArgumentParser.error() method # to report it to the user diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py index e032d5c2..d6c675e3 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py @@ -80,6 +80,12 @@ def _remove_null_leaves(self, schema: Dict[str, Any]): # the schema return remap(schema, visit=lambda p, k, v: v is not None) - def validate(self): - """Return an instance of the `pydantic` modeled validated with data passed from the command line.""" - return self.model.model_validate(self.schema) + def validate(self) -> Tuple[PydanticModelT, BaseModel]: + """Return the root of the model, as well as the sub-model for the bottom subcommand""" + model = self.model.model_validate(self.schema) + subcommand = model + + for step in self.subcommand_path: + subcommand = getattr(subcommand, step) + + return model, subcommand