From 7d6bbfe745fd064d2a85a7ce9a738df939e874af Mon Sep 17 00:00:00 2001
From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com>
Date: Tue, 29 Oct 2024 21:51:15 +0100
Subject: [PATCH] fix: based on PR comments

---
 src/koheesio/context.py                     |  7 +++---
 src/koheesio/models/__init__.py             | 22 +++++++++++++++++--
 src/koheesio/models/reader.py               | 10 +++++----
 src/koheesio/models/sql.py                  |  5 +----
 src/koheesio/spark/__init__.py              | 17 +++++++++++++-
 src/koheesio/spark/readers/memory.py        |  6 +++--
 .../spark/transformations/sql_transform.py  |  6 ++---
 src/koheesio/spark/writers/delta/stream.py  |  4 +---
 tests/spark/readers/test_memory.py          |  3 +--
 .../date_time/test_interval.py              |  2 +-
 10 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/src/koheesio/context.py b/src/koheesio/context.py
index 925ce67..e0b818a 100644
--- a/src/koheesio/context.py
+++ b/src/koheesio/context.py
@@ -14,9 +14,9 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Dict, Iterator, Union
 from collections.abc import Mapping
 from pathlib import Path
+from typing import Any, Dict, Iterator, Union
 
 import jsonpickle  # type: ignore[import-untyped]
 import tomli
@@ -87,8 +87,9 @@ def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
             if isinstance(arg, Context):
                 kwargs = kwargs.update(arg.to_dict())
 
-        for key, value in kwargs.items():
-            self.__dict__[key] = self.process_value(value)
+        if kwargs:
+            for key, value in kwargs.items():
+                self.__dict__[key] = self.process_value(value)
 
     def __str__(self) -> str:
         """Returns a string representation of the Context."""
diff --git a/src/koheesio/models/__init__.py b/src/koheesio/models/__init__.py
index d0ca34b..1b33e6a 100644
--- a/src/koheesio/models/__init__.py
+++ b/src/koheesio/models/__init__.py
@@ -9,14 +9,32 @@
 Transformation and Reader classes.
""" -from typing import Annotated, Any, Dict, List, Optional, Union +from __future__ import annotations + from abc import ABC from functools import cached_property from pathlib import Path +from typing import Annotated, Any, Dict, List, Optional, Union # to ensure that koheesio.models is a drop in replacement for pydantic from pydantic import BaseModel as PydanticBaseModel -from pydantic import * # noqa +from pydantic import ( + BeforeValidator, + ConfigDict, + Field, + InstanceOf, + PositiveInt, + PrivateAttr, + SecretBytes, + SecretStr, + SkipValidation, + conint, + conlist, + constr, + field_serializer, + field_validator, + model_validator, +) # noinspection PyProtectedMember from pydantic._internal._generics import PydanticGenericMetadata diff --git a/src/koheesio/models/reader.py b/src/koheesio/models/reader.py index 4ea9db9..3f35192 100644 --- a/src/koheesio/models/reader.py +++ b/src/koheesio/models/reader.py @@ -2,11 +2,13 @@ Module for the BaseReader class """ -from typing import Optional from abc import ABC, abstractmethod +from typing import Optional, TypeVar from koheesio import Step -from koheesio.spark import DataFrame + +# Define a type variable that can be any type of DataFrame +DataFrameType = TypeVar("DataFrameType") class BaseReader(Step, ABC): @@ -27,7 +29,7 @@ class BaseReader(Step, ABC): """ @property - def df(self) -> Optional[DataFrame]: + def df(self) -> Optional[DataFrameType]: """Shorthand for accessing self.output.df If the output.df is None, .execute() will be run first """ @@ -42,7 +44,7 @@ def execute(self) -> Step.Output: """ pass - def read(self) -> DataFrame: + def read(self) -> DataFrameType: """Read from a Reader without having to call the execute() method directly""" self.execute() return self.output.df diff --git a/src/koheesio/models/sql.py b/src/koheesio/models/sql.py index a2ecce2..f19bc96 100644 --- a/src/koheesio/models/sql.py +++ b/src/koheesio/models/sql.py @@ -1,8 +1,8 @@ """This module contains the base class for SQL steps.""" -from typing import Any, Dict, Optional, Union from abc import ABC from pathlib import Path +from typing import Any, Dict, Optional, Union from koheesio import Step from koheesio.models import ExtraParamsMixin, Field, model_validator @@ -60,9 +60,6 @@ def _validate_sql_and_sql_path(self) -> "SqlBaseStep": @property def query(self) -> str: """Returns the query while performing params replacement""" - # query = self.sql.replace("${", "{") if self.sql else self.sql - # if "{" in query: - # query = query.format(**self.params) if self.sql: query = self.sql diff --git a/src/koheesio/spark/__init__.py b/src/koheesio/spark/__init__.py index 0a3bbca..c72cfb0 100644 --- a/src/koheesio/spark/__init__.py +++ b/src/koheesio/spark/__init__.py @@ -4,8 +4,9 @@ from __future__ import annotations -from typing import Optional +import warnings from abc import ABC +from typing import Optional from pydantic import Field @@ -72,3 +73,17 @@ def _get_active_spark_session(self) -> SparkStep: self.spark = get_active_session() return self + + +def current_timestamp_utc(spark): + warnings.warn( + message=( + "The current_timestamp_utc function has been moved to the koheesio.spark.functions module." + "Import it from there instead. Current import path will be deprecated in the future." 
+ ), + category=DeprecationWarning, + stacklevel=2, + ) + from koheesio.spark.functions import current_timestamp_utc as _current_timestamp_utc + + return _current_timestamp_utc(spark) diff --git a/src/koheesio/spark/readers/memory.py b/src/koheesio/spark/readers/memory.py index 90359dc..7900205 100644 --- a/src/koheesio/spark/readers/memory.py +++ b/src/koheesio/spark/readers/memory.py @@ -3,13 +3,12 @@ """ import json -from typing import Any, Dict, Optional, Union from enum import Enum from functools import partial from io import StringIO +from typing import Any, Dict, Optional, Union import pandas as pd - from pyspark.sql.types import StructType from koheesio.models import ExtraParamsMixin, Field @@ -80,6 +79,9 @@ def _csv(self) -> DataFrame: else: csv_data: str = self.data # type: ignore + if "header" in self.params and self.params["header"] is True: + self.params["header"] = 0 + pandas_df = pd.read_csv(StringIO(csv_data), **self.params) # type: ignore df = self.spark.createDataFrame(pandas_df, schema=self.schema_) # type: ignore diff --git a/src/koheesio/spark/transformations/sql_transform.py b/src/koheesio/spark/transformations/sql_transform.py index b178f3e..030e1d4 100644 --- a/src/koheesio/spark/transformations/sql_transform.py +++ b/src/koheesio/spark/transformations/sql_transform.py @@ -35,9 +35,9 @@ def execute(self) -> Transformation.Output: if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session() and self.df.isStreaming: raise RuntimeError( - """SQL Transform is not supported in remote sessions with streaming dataframes. - See https://issues.apache.org/jira/browse/SPARK-45957 - It is fixed in PySpark 4.0.0""" + "SQL Transform is not supported in remote sessions with streaming dataframes." + "See https://issues.apache.org/jira/browse/SPARK-45957" + "It is fixed in PySpark 4.0.0" ) self.df.createOrReplaceTempView(table_name) diff --git a/src/koheesio/spark/writers/delta/stream.py b/src/koheesio/spark/writers/delta/stream.py index 49877c9..aea03a5 100644 --- a/src/koheesio/spark/writers/delta/stream.py +++ b/src/koheesio/spark/writers/delta/stream.py @@ -2,8 +2,8 @@ This module defines the DeltaTableStreamWriter class, which is used to write streaming dataframes to Delta tables. 
""" -from typing import Optional from email.policy import default +from typing import Optional from pydantic import Field @@ -32,7 +32,5 @@ class Options(BaseModel): def execute(self) -> DeltaTableWriter.Output: if self.batch_function: self.streaming_query = self.writer.start() - # elif self.streaming and self.is_remote_spark_session: - # self.streaming_query = self.writer.start() else: self.streaming_query = self.writer.toTable(tableName=self.table.table_name) diff --git a/tests/spark/readers/test_memory.py b/tests/spark/readers/test_memory.py index 40fee52..21b5d53 100644 --- a/tests/spark/readers/test_memory.py +++ b/tests/spark/readers/test_memory.py @@ -1,6 +1,5 @@ import pytest from chispa import assert_df_equality - from pyspark.sql.types import StructType from koheesio.spark.readers.memory import DataFormat, InMemoryDataReader @@ -14,7 +13,7 @@ class TestInMemoryDataReader: "data,format,params,expect_filter", [ pytest.param( - "id,string\n1,hello\n2,world", DataFormat.CSV, {"header":0}, "id < 3" + "id,string\n1,hello\n2,world", DataFormat.CSV, {"header":True}, "id < 3" ), pytest.param( b"id,string\n1,hello\n2,world", DataFormat.CSV, {"header":0}, "id < 3" diff --git a/tests/spark/transformations/date_time/test_interval.py b/tests/spark/transformations/date_time/test_interval.py index e3554e1..71208da 100644 --- a/tests/spark/transformations/date_time/test_interval.py +++ b/tests/spark/transformations/date_time/test_interval.py @@ -123,7 +123,7 @@ def test_interval(input_data, column_name, operation, interval, expected, spark) def test_interval_unhappy(spark): with pytest.raises(ValueError): - validate_interval("some random b*llsh*t") # TODO: this should raise an error, but it doesn't in REMOTE mode + validate_interval("some random sym*bol*s") # invalid operation with pytest.raises(ValueError): _ = adjust_time(col("some_col"), "invalid operation", "1 day")