Skip to content

Commit

Permalink
fix(argparse): No more false positives on contained subcommands
Browse files Browse the repository at this point in the history
  • Loading branch information
fkglr committed Feb 23, 2024
1 parent 2e0e84c commit 4b6c065
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ def __init__(
) -> None:
self.model = model
self.args = to_dict(namespace)
self.subcommand_path: tuple[str] = tuple()
self.schema: Dict[str, Any] = self._get_nested_model_fields(self.model, namespace)
self.schema = self._remove_null_leaves(self.schema)

def _get_nested_model_fields(self, model: ModelT, namespace: Namespace, parent: Optional[Tuple] = None):
def contains_subcommand(namespace: Namespace, subcommand: str):
for name, obj in vars(namespace).items():
if isinstance(obj, Namespace):
if name == subcommand:
return True
elif contains_subcommand(obj, subcommand):
return True
def _get_nested_model_fields(self, model: ModelT, namespace: Namespace):
def contains_subcommand(ns: Namespace, subcommand_path: Tuple[str]):
for step in subcommand_path:
ns = getattr(ns, step, None)

return False
if not isinstance(ns, Namespace):
return False

return True

model_fields: Dict[str, Any] = dict()

Expand All @@ -45,28 +45,25 @@ def contains_subcommand(namespace: Namespace, subcommand: str):

if field.is_a(BaseModel):
if field.is_subcommand():
if not contains_subcommand(namespace, key):
sub_command_path = (*self.subcommand_path, key)

if not contains_subcommand(namespace, sub_command_path):
continue

parent = (*parent, key) if parent is not None else (key,)
self.subcommand_path = sub_command_path

# recursively build nestes pydantic models in dict,
# which matches the actual schema the nested
# schema pydantic will be expecting
model_fields[key] = self._get_nested_model_fields(
field.model_type, namespace, parent
field.model_type, namespace
)
else:
# start with all leaves as None unless key is in top level
value = self.args.get(key, None)
if parent is not None:
# however, if travesing nested models, then the parent should
# not be None and then there is potentially a real value to get

# check full path first
# TODO: this may not be needed depending on how nested namespaces work
# since the arg groups are not nested -- just flattened
path = (*parent, key)

if len(self.subcommand_path) > 0:
path = (*self.subcommand_path, key)
value = get_path(self.args, path, value)

model_fields[key] = value
Expand Down

0 comments on commit 4b6c065

Please sign in to comment.