diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index 0d7df9c8..3c235834 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -53,7 +53,7 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) - def validate_schema( required_schema: StructType, ignore_nullable: bool = False, - _df: DataFrame = None, + df_to_be_validated: DataFrame = None, ) -> Callable[[Any, Any], Any]: """Function that validate if a given DataFrame has a given StructType as its schema. Implemented as a decorator factory so can be used both as a standalone function or as @@ -64,9 +64,9 @@ def validate_schema( :param ignore_nullable: (Optional) A flag for if nullable fields should be ignored during validation :type ignore_nullable: bool, optional - :param _df: DataFrame to validate, mandatory when called as a function. Not required + :param df_to_be_validated: DataFrame to validate, mandatory when called as a function. Not required when called as a decorator - :type _df: DataFrame + :type df_to_be_validated: DataFrame :raises DataFrameMissingStructFieldError: if any StructFields from the required schema are not included in the DataFrame schema @@ -96,12 +96,12 @@ def wrapper(*args: object, **kwargs: object) -> DataFrame: return dataframe return wrapper - if _df is None: + if df_to_be_validated is None: # This means the function is being used as a decorator return decorator # This means the function is being called directly with a DataFrame - return decorator(lambda: _df)() + return decorator(lambda: df_to_be_validated)() def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None: diff --git a/tests/test_dataframe_validator.py b/tests/test_dataframe_validator.py index 9d299392..a71a55f3 100644 --- a/tests/test_dataframe_validator.py +++ b/tests/test_dataframe_validator.py @@ -34,7 +34,7 @@ def it_raises_when_struct_field_is_missing1(): ] ) with pytest.raises(quinn.DataFrameMissingStructFieldError) as excinfo: - quinn.validate_schema(required_schema, _df=source_df) + quinn.validate_schema(required_schema, df_to_be_validated=source_df) current_spark_version = semver.Version.parse(spark.version) spark_330 = semver.Version.parse("3.3.0") @@ -53,7 +53,7 @@ def it_does_nothing_when_the_schema_matches(): StructField("age", LongType(), True), ] ) - quinn.validate_schema(required_schema, _df=source_df) + quinn.validate_schema(required_schema, df_to_be_validated=source_df) def nullable_column_mismatches_are_ignored(): data = [("jose", 1), ("li", 2), ("luisa", 3)] @@ -64,7 +64,7 @@ def nullable_column_mismatches_are_ignored(): StructField("age", LongType(), False), ] ) - quinn.validate_schema(required_schema, ignore_nullable=True, _df=source_df) + quinn.validate_schema(required_schema, ignore_nullable=True, df_to_be_validated=source_df) def describe_validate_absence_of_columns():