From 844c92efecbed1e26a83382a62d54edc3774959f Mon Sep 17 00:00:00 2001 From: fabian Date: Fri, 23 Feb 2024 16:17:39 +0100 Subject: [PATCH] Intermediate status --- src/gallia/command/base.py | 12 ++--- .../pydantic_argparse/argparse/parser.py | 53 +++++++++++++++---- .../pydantic_argparse/parsers/__init__.py | 2 +- .../pydantic_argparse/parsers/boolean.py | 14 +++-- .../pydantic_argparse/parsers/command.py | 4 +- .../pydantic_argparse/parsers/container.py | 3 +- .../pydantic_argparse/parsers/enum.py | 19 ++----- .../pydantic_argparse/parsers/literal.py | 15 ++---- .../pydantic_argparse/parsers/mapping.py | 3 +- .../pydantic_argparse/parsers/standard.py | 7 ++- .../pydantic_argparse/utils/nesting.py | 3 +- .../pydantic_argparse/utils/pydantic.py | 33 +++++++++--- 12 files changed, 105 insertions(+), 63 deletions(-) diff --git a/src/gallia/command/base.py b/src/gallia/command/base.py index 0c642597..88265f93 100644 --- a/src/gallia/command/base.py +++ b/src/gallia/command/base.py @@ -129,9 +129,7 @@ def __init__(self, parser: ArgumentParser, config: Config = Config()) -> None: def run(self, args: Namespace) -> int: ... - def run_hook( - self, variant: HookVariant, args: Namespace, exit_code: int | None = None - ) -> None: + def run_hook(self, variant: HookVariant, args: Namespace, exit_code: int | None = None) -> None: script = args.pre_hook if variant == HookVariant.PRE else args.post_hook if script is None or script == "": return @@ -154,9 +152,7 @@ def run_hook( if exit_code is not None: env["GALLIA_EXIT_CODE"] = str(exit_code) try: - p = run( - script, env=env, text=True, capture_output=True, shell=True, check=True - ) + p = run(script, env=env, text=True, capture_output=True, shell=True, check=True) stdout = p.stdout stderr = p.stderr except CalledProcessError as e: @@ -344,9 +340,7 @@ def entry_point(self, args: Namespace) -> int: logger.critical(f"Unable to lock {p}: {e}") return exitcode.OSFILE if self.HAS_ARTIFACTS_DIR: - self.artifacts_dir = self.prepare_artifactsdir( - args.artifacts_base, args.artifacts_dir - ) + self.artifacts_dir = self.prepare_artifactsdir(args.artifacts_base, args.artifacts_dir) self.log_file_handlers.append( add_zst_log_handler( logger_name="gallia", diff --git a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py index bb67a5e6..ca9479f9 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py +++ b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py @@ -21,7 +21,9 @@ import argparse import sys -from typing import Dict, Generic, List, NoReturn, Optional, Tuple, Type +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Dict, Generic, List, NoReturn, Optional, Tuple, Type, Any from pydantic import BaseModel, ValidationError @@ -31,6 +33,12 @@ from pydantic_argparse.utils.pydantic import PydanticField, PydanticModelT + +@dataclass +class ArgumentGroup: + name: str | None + + class ArgumentParser(argparse.ArgumentParser, Generic[PydanticModelT]): """Declarative and Typed Argument Parser. @@ -66,6 +74,7 @@ def __init__( epilog: Optional[str] = None, add_help: bool = True, exit_on_error: bool = True, + extra_defaults: dict[Type, dict[str, Any]] | None = None ) -> None: """Instantiates the Typed Argument Parser with its `pydantic` model. @@ -102,9 +111,16 @@ def __init__( self.version = version self.add_help = add_help self.exit_on_error = exit_on_error + self.extra_defaults = extra_defaults # Add Arguments Groups self._subcommands: Optional[argparse._SubParsersAction] = None + + # Add Arguments from Model + self._submodels: dict[str, Type[BaseModel]] = dict() + self.model = self._add_model(model) + print(vars(self.model), file=open("/tmp/after", "w")) + self._help_group = self.add_argument_group(ArgumentParser.HELP) # Add Help and Version Flags @@ -113,10 +129,6 @@ def __init__( if self.version: self._add_version_flag() - # Add Arguments from Model - self._submodels: dict[str, Type[BaseModel]] = dict() - self.model = self._add_model(model) - @property def has_submodels(self) -> bool: # noqa: D102 # this is for simple nested models as arg groups @@ -247,11 +259,14 @@ def _add_model( validators: Dict[str, utils.pydantic.PydanticValidator] = dict() parser = self if arg_group is None else arg_group + explicit_groups = {} + validation_model = model.model_construct() + # Loop through fields in model for field in PydanticField.parse_model(model): if field.is_a(BaseModel): if field.is_subcommand(): - validator = parsers.command.parse_field(self._commands(), field) + validator = parsers.command.parse_field(self._commands(), field, self.extra_defaults) else: # for any nested pydantic models, set default factory to model_construct # method. This allows pydantic to handle if no arguments from a nested @@ -261,21 +276,41 @@ def _add_model( field.info.default_factory = field.model_type.model_construct # create new arg group - group_name = str.upper(field.info.title or field.name) + group_name = field.info.title or field.name arg_group = self.add_argument_group(group_name) # recurse and parse fields below this submodel # TODO: storage of submodels not needed self._submodels[field.name] = self._add_model( model=field.model_type, - arg_group=arg_group, + arg_group=arg_group ) validator = None else: # Add field - validator = parsers.add_field(parser, field) + added = False + + if self.extra_defaults is not None and model in self.extra_defaults and field.name in self.extra_defaults[model]: + field.extra_default = self.extra_defaults[model][field.name] + try: + field.validated_extra_default = getattr(model.__pydantic_validator__.validate_assignment(validation_model, field.name, field.extra_default), field.name) + except ValidationError: + # TODO Print warning for invalid config + pass + + for annotation in field.info.metadata: + if isinstance(annotation, ArgumentGroup) and annotation.name is not None: + if annotation.name not in explicit_groups: + explicit_groups[annotation.name] = self.add_argument_group(annotation.name) + + validator = parsers.add_field(explicit_groups[annotation.name], field) + added = True + break + + if not added: + validator = parsers.add_field(parser, field) # Update validators utils.pydantic.update_validators(validators, validator) diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/__init__.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/__init__.py index ef0ec8ae..f1a4f8ff 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/__init__.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/__init__.py @@ -26,7 +26,7 @@ ) from .utils import SupportsAddArgument - +# TODO: The validators do not work for nested models, either fix or remove this functionality def add_field( parser: SupportsAddArgument, field: PydanticField, diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/boolean.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/boolean.py index 3336b0a5..b7c34f80 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/boolean.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/boolean.py @@ -47,12 +47,17 @@ def parse_field( Optional[PydanticValidator]: Possible validator method. """ # Compute Argument Intrinsics - is_inverted = not field.info.is_required() and bool(field.info.get_default()) + invalid_extra_default = field.extra_default is not None and field.validated_extra_default is None + + if field.validated_extra_default is not None: + is_inverted = field.validated_extra_default + else: + is_inverted = not field.info.is_required() and bool(field.info.get_default()) # Determine Argument Properties action = ( actions.BooleanOptionalAction - if field.info.is_required() + if field.arg_required() or invalid_extra_default else argparse._StoreFalseAction if is_inverted else argparse._StoreTrueAction @@ -60,11 +65,12 @@ def parse_field( # Add Boolean Field parser.add_argument( - field.argname(is_inverted), + field.argname(is_inverted and not invalid_extra_default), action=action, help=field.description(), dest=field.name, - required=field.info.is_required(), + required=field.arg_required(), + **field.arg_default() ) # Construct and Return Validator diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/command.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/command.py index 7dd71649..f4963c76 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/command.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/command.py @@ -11,7 +11,7 @@ """ import argparse -from typing import Optional +from typing import Optional, Type, Any from pydantic_argparse.utils.pydantic import ( PydanticField, @@ -35,6 +35,7 @@ def should_parse(field: PydanticField) -> bool: def parse_field( subparser: argparse._SubParsersAction, field: PydanticField, + extra_defaults: dict[Type, dict[str, Any]] | None = None ) -> Optional[PydanticValidator]: """Adds command pydantic field to argument parser. @@ -51,6 +52,7 @@ def parse_field( help=field.info.description, model=field.model_type, # type: ignore[call-arg] exit_on_error=False, # Allow top level parser to handle exiting + extra_defaults=extra_defaults ) # Return diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/container.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/container.py index 0150d750..494a7b77 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/container.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/container.py @@ -56,7 +56,8 @@ def parse_field( help=field.description(), dest=field.name, metavar=field.metavar(), - required=field.info.is_required(), + required=field.arg_required(), + **field.arg_default() ) # Construct and Return Validator diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/enum.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/enum.py index 5a7fc031..dc1c642f 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/enum.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/enum.py @@ -49,30 +49,19 @@ def parse_field( # Extract Enum enum_type = cast(Type[enum.Enum], field.info.annotation) - # Compute Argument Intrinsics - is_flag = len(enum_type) == 1 and not field.info.is_required() - is_inverted = is_flag and field.info.get_default() is not None - # Determine Argument Properties metavar = f"{{{', '.join(e.name for e in enum_type)}}}" - action = argparse._StoreConstAction if is_flag else argparse._StoreAction - const = ( - {} - if not is_flag - else {"const": None} - if is_inverted - else {"const": list(enum_type)[0]} - ) + action = argparse._StoreAction # Add Enum Field parser.add_argument( - field.argname(is_inverted), + field.argname(), action=action, help=field.description(), dest=field.name, metavar=metavar, - required=field.info.is_required(), - **const, # type: ignore[arg-type] + required=field.arg_required(), + **field.arg_default(), ) # Construct and Return Validator diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/literal.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/literal.py index 18f485db..13f05713 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/literal.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/literal.py @@ -54,26 +54,19 @@ def parse_field( # Extract Choices choices = get_args(field.info.annotation) - # Compute Argument Intrinsics - is_flag = len(choices) == 1 and not field.info.is_required() - is_inverted = is_flag and field.info.get_default() is not None - # Determine Argument Properties metavar = f"{{{', '.join(str(c) for c in choices)}}}" - action = argparse._StoreConstAction if is_flag else argparse._StoreAction - const = ( - {} if not is_flag else {"const": None} if is_inverted else {"const": choices[0]} - ) + action = argparse._StoreAction # Add Literal Field parser.add_argument( - field.argname(is_inverted), + field.argname(), action=action, help=field.description(), dest=field.name, metavar=metavar, - required=field.info.is_required(), - **const, # type: ignore[arg-type] + required=field.arg_required(), + **field.arg_default() ) # Construct String Representation Mapping of Choices diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/mapping.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/mapping.py index 11326814..034adc3d 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/mapping.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/mapping.py @@ -54,7 +54,8 @@ def parse_field( help=field.description(), dest=field.name, metavar=field.metavar(), - required=field.info.is_required(), + required=field.arg_required(), + **field.arg_default() ) # Construct and Return Validator diff --git a/vendor/pydantic-argparse/pydantic_argparse/parsers/standard.py b/vendor/pydantic-argparse/pydantic_argparse/parsers/standard.py index 386e7529..6a01f1d1 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/parsers/standard.py +++ b/vendor/pydantic-argparse/pydantic_argparse/parsers/standard.py @@ -34,6 +34,8 @@ def parse_field( Returns: Optional[PydanticValidator]: Possible validator method. """ + + # Add Standard Field parser.add_argument( field.argname(), @@ -41,8 +43,9 @@ def parse_field( help=field.description(), dest=field.name, metavar=field.metavar(), - required=field.info.is_required(), + required=field.arg_required(), + **field.arg_default() ) # Construct and Return Validator - return utils.pydantic.as_validator(field, lambda v: v) + return utils.pydantic.as_validator(field, lambda v: print("HEEEEEEEEEEERE")) diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py index d6c675e3..e746c3da 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py @@ -20,7 +20,8 @@ class _NestedArgumentParser(Generic[PydanticModelT]): """Parses arbitrarily nested `pydantic` models and inserts values passed at the command line.""" def __init__( - self, model: PydanticModelT | Type[PydanticModelT], namespace: Namespace + self, model: PydanticModelT | Type[PydanticModelT], + namespace: Namespace, ) -> None: self.model = model self.args = to_dict(namespace) diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py b/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py index 6194b76b..d540f747 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py @@ -11,6 +11,7 @@ """ from collections.abc import Container, Mapping +from dataclasses import dataclass from enum import Enum from typing import ( Any, @@ -42,7 +43,8 @@ NoneType = type(None) -class PydanticField(NamedTuple): +@dataclass +class PydanticField: """Simple Pydantic v2.0 field wrapper. Pydantic fields no longer store their name, so this named tuple @@ -53,6 +55,8 @@ class PydanticField(NamedTuple): name: str info: FieldInfo + extra_default: Any = None + validated_extra_default: Any = None @classmethod def parse_model( @@ -235,18 +239,25 @@ def description(self) -> str: str: Standardised description of the argument. """ # Construct Default String - if self.info.is_required(): - default = None - required = "REQUIRED:" - else: + default = "" + + if not self.info.is_required(): _default = self.info.get_default() if isinstance(_default, Enum): _default = _default.name - default = f"(default: {_default})" - required = None + default = f"default: {_default}" + + if self.extra_default is not None: + if len(default) > 0: + default += "; " + default += f"config: {self.extra_default}" + + if len(default) > 0: + default = f" ({default})" # Return Standardised Description String - return " ".join(filter(None, [required, self.info.description, default])) + description = self.info.description if self.info.description is not None else "" + return f"{description}{default}" def metavar(self) -> Optional[str]: """Generate the metavar name for the field. @@ -266,6 +277,12 @@ def metavar(self) -> Optional[str]: return "|".join(t.__name__.upper() for t in field_type) return field_type.__name__.upper() + def arg_required(self): + return self.info.is_required() and self.extra_default is None + + def arg_default(self): + return {} if self.extra_default is None else {'default': self.extra_default} + def as_validator( field: PydanticField,