Skip to content

Commit

Permalink
Additionally return the model for the selected subcommand
Browse files Browse the repository at this point in the history
  • Loading branch information
fkglr committed Feb 23, 2024
1 parent 4b6c065 commit a564d3e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
6 changes: 3 additions & 3 deletions vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import argparse
import sys
from typing import Dict, Generic, List, NoReturn, Optional, Type, cast
from typing import Dict, Generic, List, NoReturn, Optional, Tuple, Type

from pydantic import BaseModel, ValidationError

Expand Down Expand Up @@ -133,7 +133,7 @@ def has_submodels(self) -> bool: # noqa: D102
def parse_typed_args(
self,
args: Optional[List[str]] = None,
) -> PydanticModelT:
) -> Tuple[PydanticModelT, BaseModel]:
"""Parses command line arguments.
If `args` are not supplied by the user, then they are automatically
Expand All @@ -154,7 +154,7 @@ def parse_typed_args(

try:
nested_parser = _NestedArgumentParser(model=self.model, namespace=namespace)
return cast(PydanticModelT, nested_parser.validate())
return nested_parser.validate()
except ValidationError as exc:
# Catch exceptions, and use the ArgumentParser.error() method
# to report it to the user
Expand Down
12 changes: 9 additions & 3 deletions vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def _remove_null_leaves(self, schema: Dict[str, Any]):
# the schema
return remap(schema, visit=lambda p, k, v: v is not None)

def validate(self):
"""Return an instance of the `pydantic` modeled validated with data passed from the command line."""
return self.model.model_validate(self.schema)
def validate(self) -> Tuple[PydanticModelT, BaseModel]:
"""Return the root of the model, as well as the sub-model for the bottom subcommand"""
model = self.model.model_validate(self.schema)
subcommand = model

for step in self.subcommand_path:
subcommand = getattr(subcommand, step)

return model, subcommand

0 comments on commit a564d3e

Please sign in to comment.