Commit

fix: based on PR comments

mikita-sakalouski committed Oct 29, 2024
1 parent e324ce2 commit 7d6bbfe
Showing 10 changed files with 57 additions and 25 deletions.
7 changes: 4 additions & 3 deletions src/koheesio/context.py
@@ -14,9 +14,9 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Dict, Iterator, Union
 from collections.abc import Mapping
 from pathlib import Path
+from typing import Any, Dict, Iterator, Union
 
 import jsonpickle  # type: ignore[import-untyped]
 import tomli
@@ -87,8 +87,9 @@ def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
         if isinstance(arg, Context):
             kwargs = kwargs.update(arg.to_dict())
 
-        for key, value in kwargs.items():
-            self.__dict__[key] = self.process_value(value)
+        if kwargs:
+            for key, value in kwargs.items():
+                self.__dict__[key] = self.process_value(value)
 
     def __str__(self) -> str:
         """Returns a string representation of the Context."""
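
Note on the `if kwargs:` guard: two lines up, `kwargs = kwargs.update(arg.to_dict())` rebinds `kwargs` to `None`, because `dict.update()` mutates in place and returns `None`. A minimal standalone sketch of the pitfall, in plain Python with no koheesio imports:

# dict.update() returns None; assigning its result discards the dict reference.
kwargs = {"a": 1}
kwargs = kwargs.update({"b": 2})
print(kwargs)  # None

# Without the new guard, iterating would raise:
# AttributeError: 'NoneType' object has no attribute 'items'
if kwargs:
    for key, value in kwargs.items():
        print(key, value)  # never reached once kwargs is None
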
22 changes: 20 additions & 2 deletions src/koheesio/models/__init__.py
@@ -9,14 +9,32 @@
 Transformation and Reader classes.
 """
 
-from typing import Annotated, Any, Dict, List, Optional, Union
+from __future__ import annotations
+
 from abc import ABC
 from functools import cached_property
 from pathlib import Path
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 # to ensure that koheesio.models is a drop in replacement for pydantic
 from pydantic import BaseModel as PydanticBaseModel
-from pydantic import *  # noqa
+from pydantic import (
+    BeforeValidator,
+    ConfigDict,
+    Field,
+    InstanceOf,
+    PositiveInt,
+    PrivateAttr,
+    SecretBytes,
+    SecretStr,
+    SkipValidation,
+    conint,
+    conlist,
+    constr,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 
 # noinspection PyProtectedMember
 from pydantic._internal._generics import PydanticGenericMetadata
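
Dropping `from pydantic import *` in favor of an explicit list keeps `koheesio.models` a drop-in replacement for pydantic while making the re-exported surface auditable. A hedged usage sketch; `ApiConfig` is a hypothetical consumer model, and the koheesio imports assume the re-exports shown in the diff above:

from pydantic import BaseModel

from koheesio.models import Field, SecretStr, field_validator

class ApiConfig(BaseModel):  # hypothetical example model
    url: str = Field(description="Endpoint URL")
    token: SecretStr

    @field_validator("url")
    @classmethod
    def _require_https(cls, v: str) -> str:
        if not v.startswith("https://"):
            raise ValueError("url must use https")
        return v
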
10 changes: 6 additions & 4 deletions src/koheesio/models/reader.py
@@ -2,11 +2,13 @@
 Module for the BaseReader class
 """
 
-from typing import Optional
 from abc import ABC, abstractmethod
+from typing import Optional, TypeVar
 
 from koheesio import Step
-from koheesio.spark import DataFrame
+
+# Define a type variable that can be any type of DataFrame
+DataFrameType = TypeVar("DataFrameType")
 
 
 class BaseReader(Step, ABC):
@@ -27,7 +29,7 @@ class BaseReader(Step, ABC):
     """
 
     @property
-    def df(self) -> Optional[DataFrame]:
+    def df(self) -> Optional[DataFrameType]:
         """Shorthand for accessing self.output.df
         If the output.df is None, .execute() will be run first
         """
@@ -42,7 +44,7 @@ def execute(self) -> Step.Output:
         """
         pass
 
-    def read(self) -> DataFrame:
+    def read(self) -> DataFrameType:
         """Read from a Reader without having to call the execute() method directly"""
         self.execute()
         return self.output.df
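
`DataFrameType` is an unbound TypeVar and `BaseReader` is not declared `Generic`, so the new annotations document intent (any DataFrame flavor) more than they constrain type checkers. A sketch of a checker-enforced variant; `TypedReader` and `PandasCsvReader` are hypothetical, not part of koheesio:

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import pandas as pd

DataFrameType = TypeVar("DataFrameType")

class TypedReader(ABC, Generic[DataFrameType]):  # hypothetical generic base
    @abstractmethod
    def read(self) -> DataFrameType: ...

class PandasCsvReader(TypedReader[pd.DataFrame]):  # hypothetical subclass
    def __init__(self, path: str) -> None:
        self.path = path

    def read(self) -> pd.DataFrame:
        return pd.read_csv(self.path)  # checkers see the concrete return type
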
5 changes: 1 addition & 4 deletions src/koheesio/models/sql.py
@@ -1,8 +1,8 @@
 """This module contains the base class for SQL steps."""
 
-from typing import Any, Dict, Optional, Union
 from abc import ABC
 from pathlib import Path
+from typing import Any, Dict, Optional, Union
 
 from koheesio import Step
 from koheesio.models import ExtraParamsMixin, Field, model_validator
@@ -60,9 +60,6 @@ def _validate_sql_and_sql_path(self) -> "SqlBaseStep":
     @property
     def query(self) -> str:
         """Returns the query while performing params replacement"""
-        # query = self.sql.replace("${", "{") if self.sql else self.sql
-        # if "{" in query:
-        #     query = query.format(**self.params)
 
         if self.sql:
             query = self.sql
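
The deleted comments recorded the old placeholder-substitution idea. For reference, a standalone sketch reconstructing that `${name}`-to-`str.format` approach from the removed comments; the live `query` property may differ:

def render_query(sql: str, params: dict) -> str:
    # Rewrite ${name} placeholders to {name}, then fill them via str.format.
    query = sql.replace("${", "{")
    if "{" in query:
        query = query.format(**params)
    return query

print(render_query("SELECT * FROM ${table} LIMIT ${n}", {"table": "users", "n": 10}))
# SELECT * FROM users LIMIT 10
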
17 changes: 16 additions & 1 deletion src/koheesio/spark/__init__.py
@@ -4,8 +4,9 @@
 
 from __future__ import annotations
 
-from typing import Optional
+import warnings
 from abc import ABC
+from typing import Optional
 
 from pydantic import Field
 
@@ -72,3 +73,17 @@ def _get_active_spark_session(self) -> SparkStep:
 
         self.spark = get_active_session()
         return self
+
+
+def current_timestamp_utc(spark):
+    warnings.warn(
+        message=(
+            "The current_timestamp_utc function has been moved to the koheesio.spark.functions module."
+            "Import it from there instead. Current import path will be deprecated in the future."
+        ),
+        category=DeprecationWarning,
+        stacklevel=2,
+    )
+    from koheesio.spark.functions import current_timestamp_utc as _current_timestamp_utc
+
+    return _current_timestamp_utc(spark)
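
The shim keeps the old import path alive: the function-local import defers loading `koheesio.spark.functions` until the first call (avoiding import cycles at module load), and `stacklevel=2` points the warning at the caller's line rather than at the shim. A hedged usage sketch; it assumes an active `SparkSession` bound to `spark`:

import warnings

from koheesio.spark import current_timestamp_utc  # old path, still importable

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ts_col = current_timestamp_utc(spark)  # delegates to koheesio.spark.functions
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
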
6 changes: 4 additions & 2 deletions src/koheesio/spark/readers/memory.py
@@ -3,13 +3,12 @@
 """
 
 import json
-from typing import Any, Dict, Optional, Union
 from enum import Enum
 from functools import partial
 from io import StringIO
+from typing import Any, Dict, Optional, Union
 
 import pandas as pd
-
 from pyspark.sql.types import StructType
 
 from koheesio.models import ExtraParamsMixin, Field
@@ -80,6 +79,9 @@ def _csv(self) -> DataFrame:
         else:
             csv_data: str = self.data  # type: ignore
 
+        if "header" in self.params and self.params["header"] is True:
+            self.params["header"] = 0
+
         pandas_df = pd.read_csv(StringIO(csv_data), **self.params)  # type: ignore
         df = self.spark.createDataFrame(pandas_df, schema=self.schema_)  # type: ignore
 
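
The new normalization bridges an engine mismatch: Spark's CSV reader takes `header=True`, while `pandas.read_csv` expects a row index for `header` and rejects booleans. Mapping `True` to `0` makes row 0 supply the column names, matching Spark's semantics. A minimal sketch of the mismatch; the exact pandas exception and wording may vary by version:

from io import StringIO

import pandas as pd

csv_data = "id,string\n1,hello\n2,world"

try:
    pd.read_csv(StringIO(csv_data), header=True)  # Spark-style flag
except (TypeError, ValueError):
    pass  # pandas: "Passing a bool to header is invalid" (wording may vary)

df = pd.read_csv(StringIO(csv_data), header=0)  # normalized form
print(list(df.columns))  # ['id', 'string']
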
6 changes: 3 additions & 3 deletions src/koheesio/spark/transformations/sql_transform.py
@@ -35,9 +35,9 @@ def execute(self) -> Transformation.Output:
 
         if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session() and self.df.isStreaming:
             raise RuntimeError(
-                """SQL Transform is not supported in remote sessions with streaming dataframes.
-                See https://issues.apache.org/jira/browse/SPARK-45957
-                It is fixed in PySpark 4.0.0"""
+                "SQL Transform is not supported in remote sessions with streaming dataframes."
+                "See https://issues.apache.org/jira/browse/SPARK-45957"
+                "It is fixed in PySpark 4.0.0"
             )
 
         self.df.createOrReplaceTempView(table_name)
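
One caveat with switching from a triple-quoted string to adjacent literals: implicit concatenation inserts no separator, so the message as committed runs its sentences together. A small sketch of the behavior; the separator fix at the end is a suggestion, not part of this commit:

msg = (
    "SQL Transform is not supported in remote sessions with streaming dataframes."
    "See https://issues.apache.org/jira/browse/SPARK-45957"
    "It is fixed in PySpark 4.0.0"
)
print(msg)  # ...dataframes.See https://...SPARK-45957It is fixed in PySpark 4.0.0

# A trailing space inside each literal restores the separation:
msg_fixed = (
    "SQL Transform is not supported in remote sessions with streaming dataframes. "
    "See https://issues.apache.org/jira/browse/SPARK-45957 "
    "It is fixed in PySpark 4.0.0"
)
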
4 changes: 1 addition & 3 deletions src/koheesio/spark/writers/delta/stream.py
@@ -2,8 +2,8 @@
 This module defines the DeltaTableStreamWriter class, which is used to write streaming dataframes to Delta tables.
 """
 
-from typing import Optional
 from email.policy import default
+from typing import Optional
 
 from pydantic import Field
 
@@ -32,7 +32,5 @@ class Options(BaseModel):
     def execute(self) -> DeltaTableWriter.Output:
         if self.batch_function:
             self.streaming_query = self.writer.start()
-        # elif self.streaming and self.is_remote_spark_session:
-        #     self.streaming_query = self.writer.start()
         else:
             self.streaming_query = self.writer.toTable(tableName=self.table.table_name)
3 changes: 1 addition & 2 deletions tests/spark/readers/test_memory.py
@@ -1,6 +1,5 @@
 import pytest
 from chispa import assert_df_equality
-
 from pyspark.sql.types import StructType
 
 from koheesio.spark.readers.memory import DataFormat, InMemoryDataReader
@@ -14,7 +13,7 @@ class TestInMemoryDataReader:
         "data,format,params,expect_filter",
         [
             pytest.param(
-                "id,string\n1,hello\n2,world", DataFormat.CSV, {"header":0}, "id < 3"
+                "id,string\n1,hello\n2,world", DataFormat.CSV, {"header":True}, "id < 3"
             ),
             pytest.param(
                 b"id,string\n1,hello\n2,world", DataFormat.CSV, {"header":0}, "id < 3"
2 changes: 1 addition & 1 deletion tests/spark/transformations/date_time/test_interval.py
@@ -123,7 +123,7 @@ def test_interval(input_data, column_name, operation, interval, expected, spark)
 
 def test_interval_unhappy(spark):
     with pytest.raises(ValueError):
-        validate_interval("some random b*llsh*t")  # TODO: this should raise an error, but it doesn't in REMOTE mode
+        validate_interval("some random sym*bol*s")
     # invalid operation
     with pytest.raises(ValueError):
         _ = adjust_time(col("some_col"), "invalid operation", "1 day")
