diff --git a/pyproject.toml b/pyproject.toml index a99d431..be0a549 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ authors = [ # TODO: add other contributors { name = "Danny Meijer", email = "danny.meijer@nike.com" }, { name = "Mikita Sakalouski", email = "mikita.sakalouski@nike.com" }, + { name = "Maxim Mityutko", email = "maxim.mityutko@nike.com" }, ] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -63,6 +64,11 @@ se = ["spark-expectations>=2.1.0"] sftp = ["paramiko>=2.6.0"] delta = ["delta-spark>=2.2"] excel = ["openpyxl>=3.0.0"] +# Tableau dependencies +tableau = [ + "tableauhyperapi>=0.0.19484", + "tableauserverclient>=0.25", +] dev = ["black", "isort", "ruff", "mypy", "pylint", "colorama", "types-PyYAML"] test = [ "chispa", @@ -180,6 +186,7 @@ features = [ "excel", "se", "box", + "tableau", "dev", ] @@ -244,6 +251,7 @@ features = [ "sftp", "delta", "excel", + "tableau", "dev", "test", ] @@ -399,6 +407,7 @@ features = [ "sftp", "delta", "excel", + "tableau", "dev", "test", "docs", diff --git a/src/koheesio/integrations/spark/tableau/__init__.py b/src/koheesio/integrations/spark/tableau/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py new file mode 100644 index 0000000..b5f4f5c --- /dev/null +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -0,0 +1,431 @@ +import os +from typing import Any, List, Optional, Union +from abc import ABC, abstractmethod +from pathlib import PurePath +from tempfile import TemporaryDirectory + +from tableauhyperapi import ( + NOT_NULLABLE, + NULLABLE, + Connection, + CreateMode, + HyperProcess, + Inserter, + SqlType, + TableDefinition, + TableName, + Telemetry, +) + +from pydantic import Field, conlist + +from pyspark.sql import DataFrame +from pyspark.sql.functions import col +from pyspark.sql.types import ( + BooleanType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, +) + +from koheesio.spark.readers import SparkStep +from koheesio.spark.transformations.cast_to_datatype import CastToDatatype +from koheesio.spark.utils import spark_minor_version +from koheesio.steps import Step, StepOutput + + +class HyperFile(Step, ABC): + """ + Base class for all HyperFile classes + """ + + schema_: str = Field(default="Extract", alias="schema", description="Internal schema name within the Hyper file") + table: str = Field(default="Extract", description="Table name within the Hyper file") + + @property + def table_name(self) -> TableName: + """ + Return TableName object for the Hyper file TableDefinition. + """ + return TableName(self.schema_, self.table) + + +class HyperFileReader(HyperFile, SparkStep): + """ + Read a Hyper file and return a Spark DataFrame. 
+ + Examples + -------- + ```python + df = HyperFileReader( + path=PurePath(hw.hyper_path), + ).execute().df + ``` + """ + + path: PurePath = Field( + default=..., description="Path to the Hyper file", examples=["PurePath(~/data/my-file.hyper)"] + ) + + def execute(self): + type_mapping = { + "date": StringType, + "text": StringType, + "double": FloatType, + "bool": BooleanType, + "small_int": ShortType, + "big_int": LongType, + "timestamp": StringType, + "timestamp_tz": StringType, + "int": IntegerType, + "numeric": DecimalType, + } + df_cols = [] + timestamp_cols = [] + date_cols = [] + + with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp: + with Connection(endpoint=hp.endpoint, database=self.path) as connection: + table_definition = connection.catalog.get_table_definition(name=self.table_name) + + select_cols = [] + self.log.debug(f"Schema for {self.table_name} in {self.path}:") + for column in table_definition.columns: + self.log.debug(f"|-- {column.name}: {column.type} (nullable = {column.nullability})") + + column_name = column.name.unescaped.__str__() + tableau_type = column.type.__str__().lower() + + if tableau_type.startswith("numeric"): + spark_type = DecimalType(precision=18, scale=5) + else: + spark_type = type_mapping.get(tableau_type, StringType)() + + if tableau_type == "timestamp" or tableau_type == "timestamp_tz": + timestamp_cols.append(column_name) + _col = f'cast("{column_name}" as text)' + elif tableau_type == "date": + date_cols.append(column_name) + _col = f'cast("{column_name}" as text)' + elif tableau_type.startswith("numeric"): + _col = f'cast("{column_name}" as decimal(18,5))' + else: + _col = f'"{column_name}"' + + df_cols.append(StructField(column_name, spark_type)) + select_cols.append(_col) + + data = connection.execute_list_query(f"select {','.join(select_cols)} from {self.table_name}") + + df_schema = StructType(df_cols) + df = self.spark.createDataFrame(data, schema=df_schema) + if timestamp_cols: + df = CastToDatatype(column=timestamp_cols, datatype="timestamp").transform(df) + if date_cols: + df = CastToDatatype(column=date_cols, datatype="date").transform(df) + + self.output.df = df + + +class HyperFileWriter(HyperFile): + """ + Base class for all HyperFileWriter classes + """ + + path: PurePath = Field( + default=TemporaryDirectory().name, description="Path to the Hyper file", examples=["PurePath(/tmp/hyper/)"] + ) + name: str = Field(default="extract", description="Name of the Hyper file") + table_definition: TableDefinition = Field( + default=None, + description="Table definition to write to the Hyper file as described in " + "https://tableau.github.io/hyper-db/lang_docs/py/tableauhyperapi.html#tableauhyperapi.TableDefinition", + ) + + class Output(StepOutput): + """ + Output class for HyperFileListWriter + """ + + hyper_path: PurePath = Field(default=..., description="Path to created Hyper file") + + @property + def hyper_path(self) -> Connection: + """ + Return full path to the Hyper file. + """ + if not os.path.exists(self.path): + os.makedirs(self.path) + + hyper_path = PurePath(self.path, f"{self.name}.hyper" if ".hyper" not in self.name else self.name) + self.log.info(f"Destination file: {hyper_path}") + return hyper_path + + def write(self): + self.execute() + + @abstractmethod + def execute(self): + pass + + +class HyperFileListWriter(HyperFileWriter): + """ + Write list of rows to a Hyper file. + + Reference + --------- + Datatypes in https://tableau.github.io/hyper-db/docs/sql/datatype/ for supported data types. 
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileListWriter(
+        name="test",
+        table_definition=TableDefinition(
+            table_name=TableName("Extract", "Extract"),
+            columns=[
+                TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
+                TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
+                TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
+            ]
+        ),
+        data=[
+            ["text_1", 1, datetime(2024, 1, 1, 0, 0, 0, 0)],
+            ["text_2", 2, datetime(2024, 1, 2, 0, 0, 0, 0)],
+            ["text_3", None, None],
+        ],
+    ).execute()
+
+    # do something with returned file path
+    hw.hyper_path
+    ```
+    """
+
+    data: conlist(List[Any], min_length=1) = Field(default=..., description="List of rows to write to the Hyper file")
+
+    def execute(self):
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+            with Connection(
+                endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE
+            ) as connection:
+                connection.catalog.create_schema(schema=self.table_definition.table_name.schema_name)
+                connection.catalog.create_table(table_definition=self.table_definition)
+                with Inserter(connection, self.table_definition) as inserter:
+                    inserter.add_rows(rows=self.data)
+                    inserter.execute()
+
+        self.output.hyper_path = self.hyper_path
+
+
+class HyperFileParquetWriter(HyperFileWriter):
+    """
+    Read one or more parquet files and write them to a Hyper file.
+
+    Notes
+    -----
+    This method is much faster than HyperFileListWriter for large files.
+
+    References
+    ----------
+    Copy from external format: https://tableau.github.io/hyper-db/docs/sql/command/copy_from
+    Supported datatypes: https://tableau.github.io/hyper-db/docs/sql/datatype/
+    Parquet format limitations:
+        https://tableau.github.io/hyper-db/docs/sql/external/formats/#external-format-parquet
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileParquetWriter(
+        name="test",
+        table_definition=TableDefinition(
+            table_name=TableName("Extract", "Extract"),
+            columns=[
+                TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
+                TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
+                TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
+            ]
+        ),
+        files=["/my-path/parquet-1.snappy.parquet", "/my-path/parquet-2.snappy.parquet"]
+    ).execute()
+
+    # do something with returned file path
+    hw.hyper_path
+    ```
+    """
+
+    file: conlist(Union[str, PurePath], min_length=1) = Field(
+        default=..., alias="files", description="One or more parquet files to write to the Hyper file"
+    )
+
+    def execute(self):
+        _file = [str(f) for f in self.file]
+        array_files = "'" + "','".join(_file) + "'"
+
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+            with Connection(
+                endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE
+            ) as connection:
+                connection.catalog.create_schema(schema=self.table_definition.table_name.schema_name)
+                connection.catalog.create_table(table_definition=self.table_definition)
+                sql = f'copy "{self.schema_}"."{self.table}" ' f"from array [{array_files}] " f"with (format parquet)"
+                self.log.debug(f"Executing SQL: {sql}")
+                connection.execute_command(sql)
+
+        self.output.hyper_path = self.hyper_path
+
+
+class HyperFileDataFrameWriter(HyperFileWriter):
+    """
+    Write a Spark DataFrame to a Hyper file.
+    The process will write the DataFrame to a parquet file and then use the HyperFileParquetWriter to write to the
+    Hyper file.
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileDataFrameWriter(
+        df=spark.createDataFrame([(1, "foo"), (2, "bar")], ["id", "name"]),
+        name="test",
+    ).execute()
+
+    # do something with returned file path
+    hw.hyper_path
+    ```
+    """
+    df: DataFrame = Field(default=..., description="Spark DataFrame to write to the Hyper file")
+    table_definition: Optional[TableDefinition] = None  # table_definition is not required for this class
+
+    @staticmethod
+    def table_definition_column(column: StructField) -> TableDefinition.Column:
+        """
+        Convert a Spark StructField to a Tableau Hyper SqlType
+        """
+        type_mapping = {
+            IntegerType(): SqlType.int,
+            LongType(): SqlType.big_int,
+            ShortType(): SqlType.small_int,
+            DoubleType(): SqlType.double,
+            FloatType(): SqlType.double,
+            BooleanType(): SqlType.bool,
+            DateType(): SqlType.date,
+            StringType(): SqlType.text,
+        }
+
+        # Handling the TimestampNTZType for Spark 3.4+
+        # Mapping both TimestampType and TimestampNTZType to NTZ type of Hyper
+        if spark_minor_version >= 3.4:
+            from pyspark.sql.types import TimestampNTZType
+
+            type_mapping[TimestampNTZType()] = SqlType.timestamp
+            type_mapping[TimestampType()] = SqlType.timestamp
+        # In older versions of Spark, only TimestampType is available and is mapped to TZ type of Hyper
+        else:
+            type_mapping[TimestampType()] = SqlType.timestamp_tz
+
+        if column.dataType in type_mapping:
+            sql_type = type_mapping[column.dataType]()
+        elif str(column.dataType).startswith("DecimalType"):
+            # Tableau Hyper API limits the precision to 18 decimal places
+            # noinspection PyUnresolvedReferences
+            sql_type = SqlType.numeric(
+                precision=column.dataType.precision if column.dataType.precision <= 18 else 18,
+                scale=column.dataType.scale,
+            )
+        else:
+            raise ValueError(f"Unsupported datatype '{column.dataType}' for column '{column.name}'.")
+
+        return TableDefinition.Column(
+            name=column.name, type=sql_type, nullability=NULLABLE if column.nullable else NOT_NULLABLE
+        )
+
+    @property
+    def _table_definition(self) -> TableDefinition:
+        schema = self.df.schema
+        columns = list(map(self.table_definition_column, schema))
+
+        td = TableDefinition(table_name=self.table_name, columns=columns)
+        self.log.debug(f"Table definition for {self.table_name}:")
+        for column in td.columns:
+            self.log.debug(f"|-- {column.name}: {column.type} (nullable = {column.nullability})")
+
+        return td
+
+    def clean_dataframe(self) -> DataFrame:
+        """
+        - Replace NULLs for string and numeric columns
+        - Convert data types to ensure compatibility with Tableau Hyper API
+        """
+        _df = self.df
+        _schema = self.df.schema
+
+        integer_cols = [field.name for field in _schema if field.dataType == IntegerType()]
+        long_cols = [field.name for field in _schema if field.dataType == LongType()]
+        double_cols = [field.name for field in _schema if field.dataType == DoubleType()]
+        float_cols = [field.name for field in _schema if field.dataType == FloatType()]
+        string_cols = [field.name for field in _schema if field.dataType == StringType()]
+        decimal_cols = [field for field in _schema if str(field.dataType).startswith("DecimalType")]
+        timestamp_cols = [field.name for field in _schema if field.dataType == TimestampType()]
+
+        # Cast decimal fields to DecimalType(18,5)
+        for d_col in decimal_cols:
+            # noinspection PyUnresolvedReferences
+            if d_col.dataType.precision > 18:
+                _df = _df.withColumn(d_col.name,
col(d_col.name).cast(DecimalType(precision=18, scale=5))) + + # Handling the TimestampNTZType for Spark 3.4+ + # Any TimestampType column will be cast to TimestampNTZType for compatibility with Tableau Hyper API + if spark_minor_version >= 3.4: + from pyspark.sql.types import TimestampNTZType + + for t_col in timestamp_cols: + _df = _df.withColumn(t_col, col(t_col).cast(TimestampNTZType())) + + # Replace null and NaN values with 0 + if len(integer_cols) > 0: + _df = _df.na.fill(0, integer_cols) + elif len(long_cols) > 0: + _df = _df.na.fill(0, long_cols) + elif len(double_cols) > 0: + _df = _df.na.fill(0.0, double_cols) + elif len(float_cols) > 0: + _df = _df.na.fill(0.0, float_cols) + elif len(decimal_cols) > 0: + _df = _df.na.fill(0.0, decimal_cols) + elif len(string_cols) > 0: + _df = _df.na.fill("", string_cols) + + return _df + + def write_parquet(self): + _path = self.path.joinpath("parquet") + ( + self.clean_dataframe() + .coalesce(1) + .write.option("delimiter", ",") + .option("header", "true") + .mode("overwrite") + .parquet(_path.as_posix()) + ) + + for _, _, files in os.walk(_path): + for file in files: + if file.endswith(".parquet"): + fp = PurePath(_path, file) + self.log.info("Parquet file created: %s", fp) + return [fp] + + def execute(self): + w = HyperFileParquetWriter( + path=self.path, name=self.name, table_definition=self._table_definition, files=self.write_parquet() + ) + w.execute() + self.output.hyper_path = w.output.hyper_path diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py new file mode 100644 index 0000000..023305d --- /dev/null +++ b/src/koheesio/integrations/spark/tableau/server.py @@ -0,0 +1,229 @@ +import os +from typing import ContextManager, Optional, Union +from enum import Enum +from pathlib import PurePath + +import urllib3 +from tableauserverclient import ( + DatasourceItem, + Pager, + PersonalAccessTokenAuth, + ProjectItem, + Server, + TableauAuth, +) + +from pydantic import Field, SecretStr + +from koheesio.models import model_validator +from koheesio.steps import Step, StepOutput + + +class TableauServer(Step): + """ + Base class for Tableau server interactions. Class provides authentication and project identification functionality. + """ + url: str = Field( + default=..., + alias="url", + description="Hostname for the Tableau server, e.g. 
tableau.my-org.com", + examples=["tableau.my-org.com"], + ) + user: str = Field(default=..., alias="user", description="Login name for the Tableau user") + password: SecretStr = Field(default=..., alias="password", description="Password for the Tableau user") + site_id: str = Field( + default=..., + alias="site_id", + description="Identifier for the Tableau site, as used in the URL: https://tableau.my-org.com/#/site/SITE_ID", + ) + version: str = Field( + default="3.14", + alias="version", + description="Version of the Tableau server API", + ) + token_name: Optional[str] = Field( + default=None, + alias="token_name", + description="Name of the Tableau Personal Access Token", + ) + token_value: Optional[SecretStr] = Field( + default=None, + alias="token_value", + description="Value of the Tableau Personal Access Token", + ) + project: Optional[str] = Field( + default=None, + alias="project", + description="Name of the project on the Tableau server", + ) + parent_project: Optional[str] = Field( + default=None, + alias="parent_project", + description="Name of the parent project on the Tableau server, use 'root' for the root project.", + ) + project_id: Optional[str] = Field( + default=None, + alias="project_id", + description="ID of the project on the Tableau server", + ) + + def __init__(self, **data): + super().__init__(**data) + self.server = None + + @model_validator(mode="after") + def validate_project(cls, data: dict) -> dict: + """Validate when project and project_id are provided at the same time.""" + project = data.get("project") + project_id = data.get("project_id") + + if project and project_id: + raise ValueError("Both 'project' and 'project_id' parameters cannot be provided at the same time.") + + if not project and not project_id: + raise ValueError("Either 'project' or 'project_id' parameters should be provided, none is set") + + @property + def auth(self) -> ContextManager: + """ + Authenticate on the Tableau server. + + Examples + -------- + ```python + with self._authenticate(): + self.server.projects.get() + ``` + + Returns + ------- + ContextManager for TableauAuth or PersonalAccessTokenAuth authorization object + """ + # Suppress 'InsecureRequestWarning' + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + tableau_auth = TableauAuth(username=self.user, password=self.password.get_secret_value(), site_id=self.site_id) + + if self.token_name and self.token_value: + self.log.info( + "Token details provided, this will take precedence over username and password authentication." + ) + tableau_auth = PersonalAccessTokenAuth( + token_name=self.token_name, personal_access_token=self.token_value, site_id=self.site_id + ) + + self.server = Server(self.url) + self.server.version = self.version + self.server.add_http_options({"verify": False}) + + return self.server.auth.sign_in(tableau_auth) + + @property + def working_project(self) -> Union[ProjectItem, None]: + """ + Identify working project by using `project` and `parent_project` (if necessary) class properties. + The goal is to uniquely identify specific project on the server. If multiple projects have the same + name, the `parent_project` attribute of the TableauServer is required. + + Notes + ----- + Set `parent_project` value to 'root' if the project is located in the root directory. + + If `id` of the project is known, it can be used in `project_id` parameter, then the detection of the working + project using the `project` and `parent_project` attributes is skipped. 
+ + Returns + ------- + ProjectItem object representing the working project + """ + + with self.auth: + all_projects = Pager(self.server.projects) + parent, lim_p = None, [] + + for project in all_projects: + if project.id == self.project_id: + lim_p = [project] + self.log.info(f"\nProject ID provided directly:\n\tName: {lim_p[0].name}\n\tID: {lim_p[0].id}") + break + + # Identify parent project + if project.name.strip() == self.parent_project and not self.project_id: + parent = project + self.log.info(f"\nParent project identified:\n\tName: {parent.name}\n\tID: {parent.id}") + + # Identify project(s) + if project.name.strip() == self.project and not self.project_id: + lim_p.append(project) + + # Further filter the list of projects by parent project id + if self.parent_project == "root" and not self.project_id: + lim_p = [p for p in lim_p if not p.parent_id] + elif self.parent_project and parent and not self.project_id: + lim_p = [p for p in lim_p if p.parent_id == parent.id] + + if len(lim_p) > 1: + raise ValueError( + "Multiple projects with the same name exist on the server, " + "please provide `parent_project` attribute." + ) + elif len(lim_p) == 0: + raise ValueError("Working project could not be identified.") + else: + self.log.info(f"\nWorking project identified:\n\tName: {lim_p[0].name}\n\tID: {lim_p[0].id}") + return lim_p[0] + + def execute(self): + raise NotImplementedError("Method `execute` must be implemented in the subclass.") + + +class TableauHyperPublishMode(str, Enum): + """ + Publishing modes for the TableauHyperPublisher. + """ + + APPEND = Server.PublishMode.Append + OVERWRITE = Server.PublishMode.Overwrite + + +class TableauHyperPublisher(TableauServer): + """ + Publish the given Hyper file to the Tableau server. Hyper file will be treated by Tableau server as a datasource. 
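+
+    Examples
+    --------
+    A minimal usage sketch; the server details, project and file path below are illustrative placeholders.
+
+    ```python
+    publisher = TableauHyperPublisher(
+        url="https://tableau.my-org.com",
+        user="user",
+        password="secret",
+        site_id="my_site",
+        project_id="1",
+        hyper_path=PurePath("/tmp/hyper/extract.hyper"),
+        datasource_name="my_datasource",
+    )
+    publisher.execute()
+
+    # the published datasource is available on the step output
+    publisher.output.datasource_item
+    ```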
+ """ + datasource_name: str = Field(default=..., description="Name of the datasource to publish") + hyper_path: PurePath = Field(default=..., description="Path to Hyper file") + publish_mode: TableauHyperPublishMode = Field( + default=TableauHyperPublishMode.OVERWRITE, + description="Publish mode for the Hyper file", + ) + + class Output(StepOutput): + """ + Output class for TableauHyperPublisher + """ + + datasource_item: DatasourceItem = Field( + default=..., description="DatasourceItem object representing the published datasource" + ) + + def execute(self): + # Ensure that the Hyper File exists + if not os.path.isfile(self.hyper_path): + raise FileNotFoundError(f"Hyper file not found at: {self.hyper_path.as_posix()}") + + with self.auth: + # Finally, publish the Hyper File to the Tableau server + self.log.info(f'Publishing Hyper File located at: "{self.hyper_path.as_posix()}"') + self.log.debug(f"Create mode: {self.publish_mode}") + + datasource_item = self.server.datasources.publish( + datasource_item=DatasourceItem(project_id=self.working_project.id, name=self.datasource_name), + file=self.hyper_path.as_posix(), + mode=self.publish_mode, + ) + self.log.info(f"Published datasource to Tableau server with the id: {datasource_item.id}") + + self.output.datasource_item = datasource_item + + def publish(self): + self.execute() diff --git a/tests/_data/readers/hyper_file/dummy.hyper b/tests/_data/readers/hyper_file/dummy.hyper new file mode 100644 index 0000000..e4a835c Binary files /dev/null and b/tests/_data/readers/hyper_file/dummy.hyper differ diff --git a/tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc b/tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc new file mode 100644 index 0000000..618ce37 Binary files /dev/null and b/tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc differ diff --git a/tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc b/tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc new file mode 100644 index 0000000..259c6a4 Binary files /dev/null and b/tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc differ diff --git a/tests/_data/readers/parquet_file/part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet b/tests/_data/readers/parquet_file/part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet new file mode 100644 index 0000000..c97d648 Binary files /dev/null and b/tests/_data/readers/parquet_file/part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet differ diff --git a/tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet b/tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet new file mode 100644 index 0000000..70cc5b2 Binary files /dev/null and b/tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet differ diff --git a/tests/spark/integrations/tableau/test_hyper.py b/tests/spark/integrations/tableau/test_hyper.py new file mode 100644 index 0000000..73dcfef --- /dev/null +++ b/tests/spark/integrations/tableau/test_hyper.py @@ -0,0 +1,109 @@ +from datetime import datetime +from pathlib import Path, PurePath + +import pytest + +from 
koheesio.integrations.spark.tableau.hyper import ( + NOT_NULLABLE, + NULLABLE, + HyperFileDataFrameWriter, + HyperFileListWriter, + HyperFileParquetWriter, + HyperFileReader, + SqlType, + TableDefinition, + TableName, +) + +pytestmark = pytest.mark.spark + + +class TestHyper: + @pytest.fixture() + def parquet_file(self, data_path): + path = f"{data_path}/readers/parquet_file" + return Path(path).glob("**/*.parquet") + + @pytest.fixture() + def hyper_file(self, data_path): + return f"{data_path}/readers/hyper_file/dummy.hyper" + + def test_hyper_file_reader(self, hyper_file): + df = ( + HyperFileReader( + path=hyper_file, + ) + .execute() + .df + ) + + assert df.count() == 3 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_list_writer(self, spark): + hw = HyperFileListWriter( + name="test", + table_definition=TableDefinition( + table_name=TableName("Extract", "Extract"), + columns=[ + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), + ], + ), + data=[ + ["text_1", 1, datetime(2024, 1, 1, 0, 0, 0, 0)], + ["text_2", 2, datetime(2024, 1, 2, 0, 0, 0, 0)], + ["text_3", None, None], + ], + ).execute() + + df = ( + HyperFileReader( + path=PurePath(hw.hyper_path), + ) + .execute() + .df + ) + + assert df.count() == 3 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_parquet_writer(self, data_path, parquet_file): + hw = HyperFileParquetWriter( + name="test", + table_definition=TableDefinition( + table_name=TableName("Extract", "Extract"), + columns=[ + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), + ], + ), + files=parquet_file, + ).execute() + + df = HyperFileReader(path=PurePath(hw.hyper_path)).execute().df + + assert df.count() == 6 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_dataframe_writer(self, data_path, df_with_all_types): + hw = HyperFileDataFrameWriter( + name="test", + df=df_with_all_types.drop("void", "byte", "binary", "array", "map", "float"), + ).execute() + + df = HyperFileReader(path=PurePath(hw.hyper_path)).execute().df + assert df.count() == 1 + assert df.dtypes == [ + ("short", "smallint"), + ("integer", "int"), + ("long", "bigint"), + ("double", "float"), + ("decimal", "decimal(18,5)"), + ("string", "string"), + ("boolean", "boolean"), + ("timestamp", "timestamp"), + ("date", "date"), + ] diff --git a/tests/spark/integrations/tableau/test_server.py b/tests/spark/integrations/tableau/test_server.py new file mode 100644 index 0000000..e2dc5ec --- /dev/null +++ b/tests/spark/integrations/tableau/test_server.py @@ -0,0 +1,126 @@ +from typing import Any + +import pytest +from tableauserverclient import DatasourceItem + +from koheesio.integrations.spark.tableau.server import TableauHyperPublisher + + +class TestTableauServer: + @pytest.fixture(autouse=False) + def server(self, mocker): + __server = mocker.patch("koheesio.integrations.spark.tableau.server.Server") + __mock_server = __server.return_value + + from koheesio.integrations.spark.tableau.server import TableauServer + + # 
Mocking various returns from the Tableau server + def create_mock_object(name_prefix: str, object_id: int, spec: Any = None, project_id: int = None): + obj = mocker.MagicMock() if not spec else mocker.MagicMock(spec=spec) + obj.id = f"{object_id}" + obj.name = f"{name_prefix}-{object_id}" + obj.project_id = f"{project_id}" if project_id else None + return obj + + def create_mock_pagination(length: int): + obj = mocker.MagicMock() + obj.total_available = length + return obj + + # Projects + mock_projects = [] + mock_projects_pagination = mocker.MagicMock() + + def create_mock_project(name_prefix: str, project_id: int, parent_id: int = None): + mock_project = mocker.MagicMock() + mock_project.name = f"{name_prefix}-{project_id}" + mock_project.id = f"{project_id}" if not parent_id else f"{project_id * parent_id}" + mock_project.parent_id = f"{parent_id}" if parent_id else None + return mock_project + + # Parent Projects + for i in range(1, 3): + mock_projects.append(create_mock_project(name_prefix="parent-project", project_id=i)) + # Child Projects + r = range(3, 5) if i == 1 else range(4, 6) + for ix in r: + mock_projects.append(create_mock_project(name_prefix="project", project_id=ix, parent_id=i)) + + mock_projects_pagination.total_available = len(mock_projects) + + __mock_server.projects.get.return_value = [ + mock_projects, + mock_projects_pagination, + ] + + # Data Sources + mock_ds = create_mock_object("datasource", 1) + mock_conn = create_mock_object("connection", 1) + mock_conn.type = "baz" + mock_ds.connections = [mock_conn] + __mock_server.datasources.get.return_value = [ + [mock_ds], + create_mock_pagination(1), + ] + + __mock_server.datasources.publish.return_value = create_mock_object( + "published_datasource", 1337, spec=DatasourceItem + ) + + yield TableauServer( + url="https://tableau.domain.com", user="user", password="pass", site_id="site", project_id="1" + ) + + @pytest.fixture() + def hyper_file(self, data_path): + return f"{data_path}/readers/hyper_file/dummy.hyper" + + def test_working_project_w_project_id(self, server): + server.project_id = "3" + wp = server.working_project + assert wp.id == "3" and wp.name == "project-3" + + def test_working_project_w_project_name(self, server): + server.project_id = None + server.project = "project-5" + wp = server.working_project + assert wp.id == "10" and wp.name == "project-5" + + def test_working_project_w_project_name_and_parent_project(self, server): + server.project_id = None + server.project = "project-4" + server.parent_project = "parent-project-1" + wp = server.working_project + assert wp.id == "4" and wp.name == "project-4" + + def test_working_project_w_project_name_and_root(self, server): + server.project_id = None + server.project = "parent-project-1" + server.parent_project = "root" + wp = server.working_project + assert wp.id == "1" and wp.name == "parent-project-1" + + def test_working_project_multiple_projects(self, server): + with pytest.raises(ValueError): + server.project_id = None + server.project = "project-4" + server.working_project + + def test_working_project_unknown(self, server): + with pytest.raises(ValueError): + server.project_id = None + server.project = "project-6" + server.working_project + + def test_publish_hyper(self, server, hyper_file): + p = TableauHyperPublisher( + url="https://tableau.domain.com", + user="user", + password="pass", + site_id="site", + project_id="1", + hyper_path=hyper_file, + datasource_name="published_datasource", + ) + p.publish() + assert p.output.datasource_item.id == 
"1337"