From c0e305349fc557e50709a6f1a4f1f8d889236073 Mon Sep 17 00:00:00 2001
From: Max <43565398+maxim-mityutko@users.noreply.github.com>
Date: Wed, 4 Sep 2024 14:42:58 +0200
Subject: [PATCH] [FEATURE] Tableau Hyper (#49)

Provide the capability to write into the Tableau Hyper format, which allows loading datasource data independently of the Tableau server.

## Description
Writing data into Hyper from a list, Parquet files, or a DataFrame. Reading data from Hyper into a DataFrame. Publishing Hyper files to the Tableau server.

## Related Issue
#45

## Motivation and Context
In certain scenarios, loading data into the datasource through the Tableau server may take too much time (and, depending on Tableau server timeouts, may even fail). Creating Hyper files independently of Tableau allows much faster load times.

## How Has This Been Tested?
Unit tests / Publishing to the Tableau server

## Screenshots (if appropriate):

## Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Checklist:
- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [ ] I have updated the documentation accordingly.
- [x] I have read the **CONTRIBUTING** document.
- [x] I have added tests to cover my changes.
- [x] All new and existing tests passed.
---
 pyproject.toml | 9 +
 .../integrations/spark/tableau/__init__.py | 0
 .../integrations/spark/tableau/hyper.py | 431 ++++++++++++++++++
 .../integrations/spark/tableau/server.py | 229 ++++++++++
 tests/_data/readers/hyper_file/dummy.hyper | Bin 0 -> 65536 bytes
 ...-95f7-5c8b70b35e09-c000.snappy.parquet.crc | Bin 0 -> 16 bytes
 ...-95f7-5c8b70b35e09-c000.snappy.parquet.crc | Bin 0 -> 16 bytes
 ...4ded-95f7-5c8b70b35e09-c000.snappy.parquet | Bin 0 -> 1010 bytes
 ...4ded-95f7-5c8b70b35e09-c000.snappy.parquet | Bin 0 -> 998 bytes
 .../spark/integrations/tableau/test_hyper.py | 109 +++++
 .../spark/integrations/tableau/test_server.py | 126 +++++
 11 files changed, 904 insertions(+)
 create mode 100644 src/koheesio/integrations/spark/tableau/__init__.py
 create mode 100644 src/koheesio/integrations/spark/tableau/hyper.py
 create mode 100644 src/koheesio/integrations/spark/tableau/server.py
 create mode 100644 tests/_data/readers/hyper_file/dummy.hyper
 create mode 100644 tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc
 create mode 100644 tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc
 create mode 100644 tests/_data/readers/parquet_file/part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet
 create mode 100644 tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet
 create mode 100644 tests/spark/integrations/tableau/test_hyper.py
 create mode 100644 tests/spark/integrations/tableau/test_server.py

diff --git a/pyproject.toml b/pyproject.toml
index a99d431..be0a549 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ authors = [
 # TODO: add other contributors
 { name = "Danny Meijer", email = "danny.meijer@nike.com" },
 { name = "Mikita Sakalouski", email = "mikita.sakalouski@nike.com" },
+ { name = "Maxim Mityutko", email = "maxim.mityutko@nike.com" },
 ]
 classifiers = [
 "Development Status :: 5 - Production/Stable",
@@ -63,6 +64,11 @@ se = ["spark-expectations>=2.1.0"]
 sftp = ["paramiko>=2.6.0"]
 delta = 
["delta-spark>=2.2"] excel = ["openpyxl>=3.0.0"] +# Tableau dependencies +tableau = [ + "tableauhyperapi>=0.0.19484", + "tableauserverclient>=0.25", +] dev = ["black", "isort", "ruff", "mypy", "pylint", "colorama", "types-PyYAML"] test = [ "chispa", @@ -180,6 +186,7 @@ features = [ "excel", "se", "box", + "tableau", "dev", ] @@ -244,6 +251,7 @@ features = [ "sftp", "delta", "excel", + "tableau", "dev", "test", ] @@ -399,6 +407,7 @@ features = [ "sftp", "delta", "excel", + "tableau", "dev", "test", "docs", diff --git a/src/koheesio/integrations/spark/tableau/__init__.py b/src/koheesio/integrations/spark/tableau/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py new file mode 100644 index 0000000..b5f4f5c --- /dev/null +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -0,0 +1,431 @@ +import os +from typing import Any, List, Optional, Union +from abc import ABC, abstractmethod +from pathlib import PurePath +from tempfile import TemporaryDirectory + +from tableauhyperapi import ( + NOT_NULLABLE, + NULLABLE, + Connection, + CreateMode, + HyperProcess, + Inserter, + SqlType, + TableDefinition, + TableName, + Telemetry, +) + +from pydantic import Field, conlist + +from pyspark.sql import DataFrame +from pyspark.sql.functions import col +from pyspark.sql.types import ( + BooleanType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, +) + +from koheesio.spark.readers import SparkStep +from koheesio.spark.transformations.cast_to_datatype import CastToDatatype +from koheesio.spark.utils import spark_minor_version +from koheesio.steps import Step, StepOutput + + +class HyperFile(Step, ABC): + """ + Base class for all HyperFile classes + """ + + schema_: str = Field(default="Extract", alias="schema", description="Internal schema name within the Hyper file") + table: str = Field(default="Extract", description="Table name within the Hyper file") + + @property + def table_name(self) -> TableName: + """ + Return TableName object for the Hyper file TableDefinition. + """ + return TableName(self.schema_, self.table) + + +class HyperFileReader(HyperFile, SparkStep): + """ + Read a Hyper file and return a Spark DataFrame. 
+ + Examples + -------- + ```python + df = HyperFileReader( + path=PurePath(hw.hyper_path), + ).execute().df + ``` + """ + + path: PurePath = Field( + default=..., description="Path to the Hyper file", examples=["PurePath(~/data/my-file.hyper)"] + ) + + def execute(self): + type_mapping = { + "date": StringType, + "text": StringType, + "double": FloatType, + "bool": BooleanType, + "small_int": ShortType, + "big_int": LongType, + "timestamp": StringType, + "timestamp_tz": StringType, + "int": IntegerType, + "numeric": DecimalType, + } + df_cols = [] + timestamp_cols = [] + date_cols = [] + + with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp: + with Connection(endpoint=hp.endpoint, database=self.path) as connection: + table_definition = connection.catalog.get_table_definition(name=self.table_name) + + select_cols = [] + self.log.debug(f"Schema for {self.table_name} in {self.path}:") + for column in table_definition.columns: + self.log.debug(f"|-- {column.name}: {column.type} (nullable = {column.nullability})") + + column_name = column.name.unescaped.__str__() + tableau_type = column.type.__str__().lower() + + if tableau_type.startswith("numeric"): + spark_type = DecimalType(precision=18, scale=5) + else: + spark_type = type_mapping.get(tableau_type, StringType)() + + if tableau_type == "timestamp" or tableau_type == "timestamp_tz": + timestamp_cols.append(column_name) + _col = f'cast("{column_name}" as text)' + elif tableau_type == "date": + date_cols.append(column_name) + _col = f'cast("{column_name}" as text)' + elif tableau_type.startswith("numeric"): + _col = f'cast("{column_name}" as decimal(18,5))' + else: + _col = f'"{column_name}"' + + df_cols.append(StructField(column_name, spark_type)) + select_cols.append(_col) + + data = connection.execute_list_query(f"select {','.join(select_cols)} from {self.table_name}") + + df_schema = StructType(df_cols) + df = self.spark.createDataFrame(data, schema=df_schema) + if timestamp_cols: + df = CastToDatatype(column=timestamp_cols, datatype="timestamp").transform(df) + if date_cols: + df = CastToDatatype(column=date_cols, datatype="date").transform(df) + + self.output.df = df + + +class HyperFileWriter(HyperFile): + """ + Base class for all HyperFileWriter classes + """ + + path: PurePath = Field( + default=TemporaryDirectory().name, description="Path to the Hyper file", examples=["PurePath(/tmp/hyper/)"] + ) + name: str = Field(default="extract", description="Name of the Hyper file") + table_definition: TableDefinition = Field( + default=None, + description="Table definition to write to the Hyper file as described in " + "https://tableau.github.io/hyper-db/lang_docs/py/tableauhyperapi.html#tableauhyperapi.TableDefinition", + ) + + class Output(StepOutput): + """ + Output class for HyperFileListWriter + """ + + hyper_path: PurePath = Field(default=..., description="Path to created Hyper file") + + @property + def hyper_path(self) -> Connection: + """ + Return full path to the Hyper file. + """ + if not os.path.exists(self.path): + os.makedirs(self.path) + + hyper_path = PurePath(self.path, f"{self.name}.hyper" if ".hyper" not in self.name else self.name) + self.log.info(f"Destination file: {hyper_path}") + return hyper_path + + def write(self): + self.execute() + + @abstractmethod + def execute(self): + pass + + +class HyperFileListWriter(HyperFileWriter): + """ + Write list of rows to a Hyper file. + + Reference + --------- + Datatypes in https://tableau.github.io/hyper-db/docs/sql/datatype/ for supported data types. 
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileListWriter(
+        name="test",
+        table_definition=TableDefinition(
+            table_name=TableName("Extract", "Extract"),
+            columns=[
+                TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
+                TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
+                TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
+            ]
+        ),
+        data=[
+            ["text_1", 1, datetime(2024, 1, 1, 0, 0, 0, 0)],
+            ["text_2", 2, datetime(2024, 1, 2, 0, 0, 0, 0)],
+            ["text_3", None, None],
+        ],
+    ).execute()
+
+    # do something with the returned file path
+    hw.hyper_path
+    ```
+    """
+
+    data: conlist(List[Any], min_length=1) = Field(default=..., description="List of rows to write to the Hyper file")
+
+    def execute(self):
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+            with Connection(
+                endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE
+            ) as connection:
+                connection.catalog.create_schema(schema=self.table_definition.table_name.schema_name)
+                connection.catalog.create_table(table_definition=self.table_definition)
+                with Inserter(connection, self.table_definition) as inserter:
+                    inserter.add_rows(rows=self.data)
+                    inserter.execute()
+
+        self.output.hyper_path = self.hyper_path
+
+
+class HyperFileParquetWriter(HyperFileWriter):
+    """
+    Read one or multiple parquet files and write them to a Hyper file.
+
+    Notes
+    -----
+    This writer is much faster than HyperFileListWriter for large amounts of data.
+
+    References
+    ----------
+    Copy from external format: https://tableau.github.io/hyper-db/docs/sql/command/copy_from
+    Datatypes in https://tableau.github.io/hyper-db/docs/sql/datatype/ for supported data types.
+    Parquet format limitations:
+        https://tableau.github.io/hyper-db/docs/sql/external/formats/#external-format-parquet
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileParquetWriter(
+        name="test",
+        table_definition=TableDefinition(
+            table_name=TableName("Extract", "Extract"),
+            columns=[
+                TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
+                TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
+                TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
+            ]
+        ),
+        files=["/my-path/parquet-1.snappy.parquet","/my-path/parquet-2.snappy.parquet"]
+    ).execute()
+
+    # do something with the returned file path
+    hw.hyper_path
+    ```
+    """
+
+    file: conlist(Union[str, PurePath], min_length=1) = Field(
+        default=..., alias="files", description="One or multiple parquet files to write to the Hyper file"
+    )
+
+    def execute(self):
+        _file = [str(f) for f in self.file]
+        array_files = "'" + "','".join(_file) + "'"
+
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+            with Connection(
+                endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE
+            ) as connection:
+                connection.catalog.create_schema(schema=self.table_definition.table_name.schema_name)
+                connection.catalog.create_table(table_definition=self.table_definition)
+                sql = f'copy "{self.schema_}"."{self.table}" ' f"from array [{array_files}] " f"with (format parquet)"
+                self.log.debug(f"Executing SQL: {sql}")
+                connection.execute_command(sql)
+
+        self.output.hyper_path = self.hyper_path
+
+
+class HyperFileDataFrameWriter(HyperFileWriter):
+    """
+    Write a Spark DataFrame to a Hyper file.
+    The process will write the DataFrame to a parquet file and then use the HyperFileParquetWriter to write to the
+    Hyper file.
+
+    Examples
+    --------
+    ```python
+    hw = HyperFileDataFrameWriter(
+        df=spark.createDataFrame([(1, "foo"), (2, "bar")], ["id", "name"]),
+        name="test",
+    ).execute()
+
+    # do something with the returned file path
+    hw.hyper_path
+    ```
+    """
+
+    df: DataFrame = Field(default=..., description="Spark DataFrame to write to the Hyper file")
+    table_definition: Optional[TableDefinition] = None  # table_definition is not required for this class
+
+    @staticmethod
+    def table_definition_column(column: StructField) -> TableDefinition.Column:
+        """
+        Convert a Spark StructField into a Tableau Hyper TableDefinition.Column with the matching SqlType.
+        """
+        type_mapping = {
+            IntegerType(): SqlType.int,
+            LongType(): SqlType.big_int,
+            ShortType(): SqlType.small_int,
+            DoubleType(): SqlType.double,
+            FloatType(): SqlType.double,
+            BooleanType(): SqlType.bool,
+            DateType(): SqlType.date,
+            StringType(): SqlType.text,
+        }
+
+        # Handling the TimestampNTZType for Spark 3.4+
+        # Mapping both TimestampType and TimestampNTZType to the NTZ type of Hyper
+        if spark_minor_version >= 3.4:
+            from pyspark.sql.types import TimestampNTZType
+
+            type_mapping[TimestampNTZType()] = SqlType.timestamp
+            type_mapping[TimestampType()] = SqlType.timestamp
+        # In older versions of Spark, only TimestampType is available and is mapped to the TZ type of Hyper
+        else:
+            type_mapping[TimestampType()] = SqlType.timestamp_tz
+
+        if column.dataType in type_mapping:
+            sql_type = type_mapping[column.dataType]()
+        elif str(column.dataType).startswith("DecimalType"):
+            # Tableau Hyper API limits the precision to 18 decimal places
+            # noinspection PyUnresolvedReferences
+            sql_type = SqlType.numeric(
+                precision=column.dataType.precision if column.dataType.precision <= 18 else 18,
+                scale=column.dataType.scale,
+            )
+        else:
+            raise ValueError(f"Unsupported datatype '{column.dataType}' for column '{column.name}'.")
+
+        return TableDefinition.Column(
+            name=column.name, type=sql_type, nullability=NULLABLE if column.nullable else NOT_NULLABLE
+        )
+
+    @property
+    def _table_definition(self) -> TableDefinition:
+        schema = self.df.schema
+        columns = list(map(self.table_definition_column, schema))
+
+        td = TableDefinition(table_name=self.table_name, columns=columns)
+        self.log.debug(f"Table definition for {self.table_name}:")
+        for column in td.columns:
+            self.log.debug(f"|-- {column.name}: {column.type} (nullable = {column.nullability})")
+
+        return td
+
+    def clean_dataframe(self) -> DataFrame:
+        """
+        - Replace NULLs for string and numeric columns
+        - Convert data types to ensure compatibility with Tableau Hyper API
+        """
+        _df = self.df
+        _schema = self.df.schema
+
+        integer_cols = [field.name for field in _schema if field.dataType == IntegerType()]
+        long_cols = [field.name for field in _schema if field.dataType == LongType()]
+        double_cols = [field.name for field in _schema if field.dataType == DoubleType()]
+        float_cols = [field.name for field in _schema if field.dataType == FloatType()]
+        string_cols = [field.name for field in _schema if field.dataType == StringType()]
+        decimal_cols = [field for field in _schema if str(field.dataType).startswith("DecimalType")]
+        timestamp_cols = [field.name for field in _schema if field.dataType == TimestampType()]
+
+        # Cast decimal columns with a precision above 18 to DecimalType(18,5),
+        # as the Tableau Hyper API limits the precision to 18
+        for d_col in decimal_cols:
+            # noinspection PyUnresolvedReferences
+            if d_col.dataType.precision > 18:
+                _df = _df.withColumn(d_col.name, col(d_col.name).cast(DecimalType(precision=18, scale=5)))
+
+        # Handling the TimestampNTZType for Spark 3.4+
+        # Any TimestampType column will be cast to TimestampNTZType for compatibility with Tableau Hyper API
+        if spark_minor_version >= 3.4:
+            from pyspark.sql.types import TimestampNTZType
+
+            for t_col in timestamp_cols:
+                _df = _df.withColumn(t_col, col(t_col).cast(TimestampNTZType()))
+
+        # Replace null and NaN values with 0 (numeric columns) or an empty string (string columns)
+        if len(integer_cols) > 0:
+            _df = _df.na.fill(0, integer_cols)
+        if len(long_cols) > 0:
+            _df = _df.na.fill(0, long_cols)
+        if len(double_cols) > 0:
+            _df = _df.na.fill(0.0, double_cols)
+        if len(float_cols) > 0:
+            _df = _df.na.fill(0.0, float_cols)
+        if len(decimal_cols) > 0:
+            _df = _df.na.fill(0.0, [field.name for field in decimal_cols])
+        if len(string_cols) > 0:
+            _df = _df.na.fill("", string_cols)
+
+        return _df
+
+    def write_parquet(self):
+        _path = self.path.joinpath("parquet")
+        (
+            self.clean_dataframe()
+            .coalesce(1)
+            .write.option("delimiter", ",")
+            .option("header", "true")
+            .mode("overwrite")
+            .parquet(_path.as_posix())
+        )
+
+        for _, _, files in os.walk(_path):
+            for file in files:
+                if file.endswith(".parquet"):
+                    fp = PurePath(_path, file)
+                    self.log.info("Parquet file created: %s", fp)
+                    return [fp]
+
+    def execute(self):
+        w = HyperFileParquetWriter(
+            path=self.path, name=self.name, table_definition=self._table_definition, files=self.write_parquet()
+        )
+        w.execute()
+        self.output.hyper_path = w.output.hyper_path

diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py
new file mode 100644
index 0000000..023305d
--- /dev/null
+++ b/src/koheesio/integrations/spark/tableau/server.py
@@ -0,0 +1,229 @@
+import os
+from typing import ContextManager, Optional, Union
+from enum import Enum
+from pathlib import PurePath
+
+import urllib3
+from tableauserverclient import (
+    DatasourceItem,
+    Pager,
+    PersonalAccessTokenAuth,
+    ProjectItem,
+    Server,
+    TableauAuth,
+)
+
+from pydantic import Field, SecretStr
+
+from koheesio.models import model_validator
+from koheesio.steps import Step, StepOutput
+
+
+class TableauServer(Step):
+    """
+    Base class for Tableau server interactions. This class provides authentication and project identification
+    functionality.
+    """
+
+    url: str = Field(
+        default=...,
+        alias="url",
+        description="Hostname for the Tableau server, e.g. 
tableau.my-org.com", + examples=["tableau.my-org.com"], + ) + user: str = Field(default=..., alias="user", description="Login name for the Tableau user") + password: SecretStr = Field(default=..., alias="password", description="Password for the Tableau user") + site_id: str = Field( + default=..., + alias="site_id", + description="Identifier for the Tableau site, as used in the URL: https://tableau.my-org.com/#/site/SITE_ID", + ) + version: str = Field( + default="3.14", + alias="version", + description="Version of the Tableau server API", + ) + token_name: Optional[str] = Field( + default=None, + alias="token_name", + description="Name of the Tableau Personal Access Token", + ) + token_value: Optional[SecretStr] = Field( + default=None, + alias="token_value", + description="Value of the Tableau Personal Access Token", + ) + project: Optional[str] = Field( + default=None, + alias="project", + description="Name of the project on the Tableau server", + ) + parent_project: Optional[str] = Field( + default=None, + alias="parent_project", + description="Name of the parent project on the Tableau server, use 'root' for the root project.", + ) + project_id: Optional[str] = Field( + default=None, + alias="project_id", + description="ID of the project on the Tableau server", + ) + + def __init__(self, **data): + super().__init__(**data) + self.server = None + + @model_validator(mode="after") + def validate_project(cls, data: dict) -> dict: + """Validate when project and project_id are provided at the same time.""" + project = data.get("project") + project_id = data.get("project_id") + + if project and project_id: + raise ValueError("Both 'project' and 'project_id' parameters cannot be provided at the same time.") + + if not project and not project_id: + raise ValueError("Either 'project' or 'project_id' parameters should be provided, none is set") + + @property + def auth(self) -> ContextManager: + """ + Authenticate on the Tableau server. + + Examples + -------- + ```python + with self._authenticate(): + self.server.projects.get() + ``` + + Returns + ------- + ContextManager for TableauAuth or PersonalAccessTokenAuth authorization object + """ + # Suppress 'InsecureRequestWarning' + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + tableau_auth = TableauAuth(username=self.user, password=self.password.get_secret_value(), site_id=self.site_id) + + if self.token_name and self.token_value: + self.log.info( + "Token details provided, this will take precedence over username and password authentication." + ) + tableau_auth = PersonalAccessTokenAuth( + token_name=self.token_name, personal_access_token=self.token_value, site_id=self.site_id + ) + + self.server = Server(self.url) + self.server.version = self.version + self.server.add_http_options({"verify": False}) + + return self.server.auth.sign_in(tableau_auth) + + @property + def working_project(self) -> Union[ProjectItem, None]: + """ + Identify working project by using `project` and `parent_project` (if necessary) class properties. + The goal is to uniquely identify specific project on the server. If multiple projects have the same + name, the `parent_project` attribute of the TableauServer is required. + + Notes + ----- + Set `parent_project` value to 'root' if the project is located in the root directory. + + If `id` of the project is known, it can be used in `project_id` parameter, then the detection of the working + project using the `project` and `parent_project` attributes is skipped. 
+ + Returns + ------- + ProjectItem object representing the working project + """ + + with self.auth: + all_projects = Pager(self.server.projects) + parent, lim_p = None, [] + + for project in all_projects: + if project.id == self.project_id: + lim_p = [project] + self.log.info(f"\nProject ID provided directly:\n\tName: {lim_p[0].name}\n\tID: {lim_p[0].id}") + break + + # Identify parent project + if project.name.strip() == self.parent_project and not self.project_id: + parent = project + self.log.info(f"\nParent project identified:\n\tName: {parent.name}\n\tID: {parent.id}") + + # Identify project(s) + if project.name.strip() == self.project and not self.project_id: + lim_p.append(project) + + # Further filter the list of projects by parent project id + if self.parent_project == "root" and not self.project_id: + lim_p = [p for p in lim_p if not p.parent_id] + elif self.parent_project and parent and not self.project_id: + lim_p = [p for p in lim_p if p.parent_id == parent.id] + + if len(lim_p) > 1: + raise ValueError( + "Multiple projects with the same name exist on the server, " + "please provide `parent_project` attribute." + ) + elif len(lim_p) == 0: + raise ValueError("Working project could not be identified.") + else: + self.log.info(f"\nWorking project identified:\n\tName: {lim_p[0].name}\n\tID: {lim_p[0].id}") + return lim_p[0] + + def execute(self): + raise NotImplementedError("Method `execute` must be implemented in the subclass.") + + +class TableauHyperPublishMode(str, Enum): + """ + Publishing modes for the TableauHyperPublisher. + """ + + APPEND = Server.PublishMode.Append + OVERWRITE = Server.PublishMode.Overwrite + + +class TableauHyperPublisher(TableauServer): + """ + Publish the given Hyper file to the Tableau server. Hyper file will be treated by Tableau server as a datasource. 
+ """ + datasource_name: str = Field(default=..., description="Name of the datasource to publish") + hyper_path: PurePath = Field(default=..., description="Path to Hyper file") + publish_mode: TableauHyperPublishMode = Field( + default=TableauHyperPublishMode.OVERWRITE, + description="Publish mode for the Hyper file", + ) + + class Output(StepOutput): + """ + Output class for TableauHyperPublisher + """ + + datasource_item: DatasourceItem = Field( + default=..., description="DatasourceItem object representing the published datasource" + ) + + def execute(self): + # Ensure that the Hyper File exists + if not os.path.isfile(self.hyper_path): + raise FileNotFoundError(f"Hyper file not found at: {self.hyper_path.as_posix()}") + + with self.auth: + # Finally, publish the Hyper File to the Tableau server + self.log.info(f'Publishing Hyper File located at: "{self.hyper_path.as_posix()}"') + self.log.debug(f"Create mode: {self.publish_mode}") + + datasource_item = self.server.datasources.publish( + datasource_item=DatasourceItem(project_id=self.working_project.id, name=self.datasource_name), + file=self.hyper_path.as_posix(), + mode=self.publish_mode, + ) + self.log.info(f"Published datasource to Tableau server with the id: {datasource_item.id}") + + self.output.datasource_item = datasource_item + + def publish(self): + self.execute() diff --git a/tests/_data/readers/hyper_file/dummy.hyper b/tests/_data/readers/hyper_file/dummy.hyper new file mode 100644 index 0000000000000000000000000000000000000000..e4a835caed37f6492c28b0c19580eadca2e5159c GIT binary patch literal 65536 zcmeI*Piz!b9Ki9{?Vl`(h6?^c#B?2uR#3A2qln?KXi18KHiB|MQl`6)c4T*UnVscd z)3}Kt9zE!h1E`5dG|?C@TI0cx7!t0?)p!t3;-7;_{k{3KJ6kL*TA?l9JuvhBzxSCp zlg{vFclcyQdugTG)IWbV>~6-J{Zu0umdKqBBP!W&TaUJJwvGFL-2BL*;QSY@T67n%9V;XX}(v}QeKB2h8Z~LdkKP!F<_u zi%u!1aA*&L6R+OlA5+IEjNv9kgU(k2(5yBXGFp z=1p^r2vbcXV2_V`dR(pqjhL!FT9Yed6ft$-!z~x*xVSGcx;@J0yfgH}0*ka{J1c9) ztae#GW@Xc9wKkVlYo1!IHmZ?s)sa#kbV+7k97g-rsLpP+W}sWGPN|uhnfE)cwTF8% zdpZgoDw7`Rx=X39RceP6S=XhyQYy7O-PWb<8R$~?^>0ukLA&36cjSp=2M%ST_BUlR z()L=@c6OaXSrMSFZf=ZXy`% z)^qXrGcxvI{5{e5J5%bLvlrgGchIbUJNIdCTFEo=w$2D(&Q6$wkJOn=PwGtOhM8t& z#?>1=YA7fDU6oR+d!;MV?L(a{-QHh+;EU_Ky=XD<{4opsE{*>zrSEs1e-)umUrHR&+O;J z%XX-A903FnKmY**5I_I{1Q0;r4i;$M-O5+T1)q3*RxS!Rb3_qr5I_I{1Q0*~0R#|0 z009ILSiu5~|0&LdgRDn4%(>HHc=}Yho$r07HAB5e!<-BIm-i1A=0q+8R+2z?y@?i~ z$iUu=1~>g!xKoe9Z9JfhB-kK;00IagfB*srAb`M)1v(qa@ohxCH;l?fDRF!wBRh*I z(0Hn(zKRb;CaT{2vY~u#m4Dby^M11uHm`ErZ$=Vq5I_I{1Q1v%0`afo20ua^bSlT+ zZ%gIi&<_LQ*<{!Wd*udY5A5>0;|ZgVe&+xgx{ zm<{{$+uI+R9Y+`e1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{ z1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0VEG9A E4aYX-Jpcdz literal 0 HcmV?d00001 diff --git a/tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc b/tests/_data/readers/parquet_file/.part-00000-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..618ce379c3a937d04fb496a52eed1de9a44a0046 GIT binary patch literal 16 XcmYc;N@ieSU}AXw%lm5mkyR4`ERqJg literal 0 HcmV?d00001 diff --git a/tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc b/tests/_data/readers/parquet_file/.part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..259c6a47a03d5ec45421fac38b72ab92040f5882 GIT binary patch literal 16 
XcmYc;N@ieSU}EU(yRvJUa@b7(C$3bv)mKwID3*BIvww5J`f(OwamObo6q~znZrlx5(dG)Il zJa|+O{tY4s;@RV#z36SPg4K%`k9*U}N1K9(Q2J)x{N^_^zj-;haj0R04cx@})_G6B@RRSYL)@Sqpk_vy&QJi^5a68Vq?u0N=t zt{mekoS1ighHwF_P3Deh$js5CQtH0glIs9Li13mu8CB4-C{+Y#Gv&Y>;mMF#Rl#VZ zS~8u!oD%v1t)LOe|B}Zg@`1x`R2|CDT1x;$;<7QtAvvlJjbwN+R)U~fEJfn2Q!OH@ z1&38sJ;}+qkaA^>r%m|~HZsw`N;d=4D=D+v^%7@!OvvjzzJ|zKn6N6qQ39RoqBsr) zsyQ)1P&L+8eMp1Q9sA;jt~OO2iMxLx!edW`H5^YVQ-kp%bwruzn9LNv=Vl6fGNH=8 z0xOx?OsOk(8Twro@ZJfVEU#P(z>!FR`zG}d3{9yyUnt} zp#vJgMV1b=|DC;h4;jbms^_z(|6{s0N8;^t{Jl|7PW38ls@72Ta?P5$o>T639Gf~? sr(rj3t#3KhZriq{Yn@h0ziPG2D;jN6r)}*@20Ys2hrYqfI)JbJ4-pRQZ2$lO literal 0 HcmV?d00001 diff --git a/tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet b/tests/_data/readers/parquet_file/part-00001-c7808dd3-0ba3-4ded-95f7-5c8b70b35e09-c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..70cc5b255f44b970b858b44ab2333bad01bf05f2 GIT binary patch literal 998 zcmah|ON-M`6uwQ{5VZ`bJ(mQ6lwfNIYM4yhSV|Dx$xwztTq#b;bKxxy}kByM95HqZy!!R=aK?`S+LdOuppl=;Jb)Fgf(CXK7IZ5 z=F$d?K1g8X&Y6343t3NJzrBH9=`tsYd~)8((1DXENm7L+uOt>?Y7BEpSXl8$qOhV` zR$*GE0$5nSv_`O>T8kVUE=>gW?v=u+89eQW_7gTXfhSO^5|Iz1!1ae!(vxFcl@s%> z9}y^^bs=)cER4+Yv|8@HyCOFbLI~w0Su$#Z3?->5V4EoiaFnMbYSjcoC-r1HeKjTa zg&JWaSii_anVxW4_IqM?;;1gY0jX3@eJSg3z7#VzH^V)Fu(P|3ZXEo{DQYnz2X?CePI|i%iFiOz~%Ku0$`VOzA4Pl9?ST z4Q1dm&k2?8qq*Xn<1~fa!c+!Le1GP7rsW}KV0s~IC@W&6J3l^dtO;ku%*KXO` nz;c+a+qTuzI_-Azmen?QG^R60x3(n%FKzOWA9%A4@E`sMLha+B literal 0 HcmV?d00001 diff --git a/tests/spark/integrations/tableau/test_hyper.py b/tests/spark/integrations/tableau/test_hyper.py new file mode 100644 index 0000000..73dcfef --- /dev/null +++ b/tests/spark/integrations/tableau/test_hyper.py @@ -0,0 +1,109 @@ +from datetime import datetime +from pathlib import Path, PurePath + +import pytest + +from koheesio.integrations.spark.tableau.hyper import ( + NOT_NULLABLE, + NULLABLE, + HyperFileDataFrameWriter, + HyperFileListWriter, + HyperFileParquetWriter, + HyperFileReader, + SqlType, + TableDefinition, + TableName, +) + +pytestmark = pytest.mark.spark + + +class TestHyper: + @pytest.fixture() + def parquet_file(self, data_path): + path = f"{data_path}/readers/parquet_file" + return Path(path).glob("**/*.parquet") + + @pytest.fixture() + def hyper_file(self, data_path): + return f"{data_path}/readers/hyper_file/dummy.hyper" + + def test_hyper_file_reader(self, hyper_file): + df = ( + HyperFileReader( + path=hyper_file, + ) + .execute() + .df + ) + + assert df.count() == 3 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_list_writer(self, spark): + hw = HyperFileListWriter( + name="test", + table_definition=TableDefinition( + table_name=TableName("Extract", "Extract"), + columns=[ + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), + ], + ), + data=[ + ["text_1", 1, datetime(2024, 1, 1, 0, 0, 0, 0)], + ["text_2", 2, datetime(2024, 1, 2, 0, 0, 0, 0)], + ["text_3", None, None], + ], + ).execute() + + df = ( + HyperFileReader( + path=PurePath(hw.hyper_path), + ) + .execute() + .df + ) + + assert df.count() == 3 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_parquet_writer(self, data_path, parquet_file): + hw = HyperFileParquetWriter( + 
name="test", + table_definition=TableDefinition( + table_name=TableName("Extract", "Extract"), + columns=[ + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), + ], + ), + files=parquet_file, + ).execute() + + df = HyperFileReader(path=PurePath(hw.hyper_path)).execute().df + + assert df.count() == 6 + assert df.dtypes == [("string", "string"), ("int", "int"), ("timestamp", "timestamp")] + + def test_hyper_file_dataframe_writer(self, data_path, df_with_all_types): + hw = HyperFileDataFrameWriter( + name="test", + df=df_with_all_types.drop("void", "byte", "binary", "array", "map", "float"), + ).execute() + + df = HyperFileReader(path=PurePath(hw.hyper_path)).execute().df + assert df.count() == 1 + assert df.dtypes == [ + ("short", "smallint"), + ("integer", "int"), + ("long", "bigint"), + ("double", "float"), + ("decimal", "decimal(18,5)"), + ("string", "string"), + ("boolean", "boolean"), + ("timestamp", "timestamp"), + ("date", "date"), + ] diff --git a/tests/spark/integrations/tableau/test_server.py b/tests/spark/integrations/tableau/test_server.py new file mode 100644 index 0000000..e2dc5ec --- /dev/null +++ b/tests/spark/integrations/tableau/test_server.py @@ -0,0 +1,126 @@ +from typing import Any + +import pytest +from tableauserverclient import DatasourceItem + +from koheesio.integrations.spark.tableau.server import TableauHyperPublisher + + +class TestTableauServer: + @pytest.fixture(autouse=False) + def server(self, mocker): + __server = mocker.patch("koheesio.integrations.spark.tableau.server.Server") + __mock_server = __server.return_value + + from koheesio.integrations.spark.tableau.server import TableauServer + + # Mocking various returns from the Tableau server + def create_mock_object(name_prefix: str, object_id: int, spec: Any = None, project_id: int = None): + obj = mocker.MagicMock() if not spec else mocker.MagicMock(spec=spec) + obj.id = f"{object_id}" + obj.name = f"{name_prefix}-{object_id}" + obj.project_id = f"{project_id}" if project_id else None + return obj + + def create_mock_pagination(length: int): + obj = mocker.MagicMock() + obj.total_available = length + return obj + + # Projects + mock_projects = [] + mock_projects_pagination = mocker.MagicMock() + + def create_mock_project(name_prefix: str, project_id: int, parent_id: int = None): + mock_project = mocker.MagicMock() + mock_project.name = f"{name_prefix}-{project_id}" + mock_project.id = f"{project_id}" if not parent_id else f"{project_id * parent_id}" + mock_project.parent_id = f"{parent_id}" if parent_id else None + return mock_project + + # Parent Projects + for i in range(1, 3): + mock_projects.append(create_mock_project(name_prefix="parent-project", project_id=i)) + # Child Projects + r = range(3, 5) if i == 1 else range(4, 6) + for ix in r: + mock_projects.append(create_mock_project(name_prefix="project", project_id=ix, parent_id=i)) + + mock_projects_pagination.total_available = len(mock_projects) + + __mock_server.projects.get.return_value = [ + mock_projects, + mock_projects_pagination, + ] + + # Data Sources + mock_ds = create_mock_object("datasource", 1) + mock_conn = create_mock_object("connection", 1) + mock_conn.type = "baz" + mock_ds.connections = [mock_conn] + __mock_server.datasources.get.return_value = [ + [mock_ds], + create_mock_pagination(1), + ] + + 
__mock_server.datasources.publish.return_value = create_mock_object( + "published_datasource", 1337, spec=DatasourceItem + ) + + yield TableauServer( + url="https://tableau.domain.com", user="user", password="pass", site_id="site", project_id="1" + ) + + @pytest.fixture() + def hyper_file(self, data_path): + return f"{data_path}/readers/hyper_file/dummy.hyper" + + def test_working_project_w_project_id(self, server): + server.project_id = "3" + wp = server.working_project + assert wp.id == "3" and wp.name == "project-3" + + def test_working_project_w_project_name(self, server): + server.project_id = None + server.project = "project-5" + wp = server.working_project + assert wp.id == "10" and wp.name == "project-5" + + def test_working_project_w_project_name_and_parent_project(self, server): + server.project_id = None + server.project = "project-4" + server.parent_project = "parent-project-1" + wp = server.working_project + assert wp.id == "4" and wp.name == "project-4" + + def test_working_project_w_project_name_and_root(self, server): + server.project_id = None + server.project = "parent-project-1" + server.parent_project = "root" + wp = server.working_project + assert wp.id == "1" and wp.name == "parent-project-1" + + def test_working_project_multiple_projects(self, server): + with pytest.raises(ValueError): + server.project_id = None + server.project = "project-4" + server.working_project + + def test_working_project_unknown(self, server): + with pytest.raises(ValueError): + server.project_id = None + server.project = "project-6" + server.working_project + + def test_publish_hyper(self, server, hyper_file): + p = TableauHyperPublisher( + url="https://tableau.domain.com", + user="user", + password="pass", + site_id="site", + project_id="1", + hyper_path=hyper_file, + datasource_name="published_datasource", + ) + p.publish() + assert p.output.datasource_item.id == "1337"
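
For reviewers who want to try the feature end to end, here is a minimal usage sketch chaining the two new building blocks introduced by this PR: `HyperFileDataFrameWriter` from `hyper.py` and `TableauHyperPublisher` from `server.py`. The Spark session, server URL, credentials, site, project, and datasource names below are placeholders, not values taken from this patch.

```python
from pyspark.sql import SparkSession

from koheesio.integrations.spark.tableau.hyper import HyperFileDataFrameWriter
from koheesio.integrations.spark.tableau.server import (
    TableauHyperPublisher,
    TableauHyperPublishMode,
)

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "foo"), (2, "bar")], ["id", "name"])

# 1. Write the DataFrame to a local Hyper file (schema and table default to "Extract").
writer = HyperFileDataFrameWriter(df=df, name="my_extract")
hyper_path = writer.execute().hyper_path

# 2. Publish the Hyper file to the Tableau server as a datasource.
#    All connection details below are placeholders.
publisher = TableauHyperPublisher(
    url="https://tableau.my-org.com",
    user="my-user",
    password="my-password",
    site_id="my-site",
    project="My Project",   # or pass project_id=... instead
    parent_project="root",  # 'root' when the project lives at the top level
    hyper_path=hyper_path,
    datasource_name="my_datasource",
    publish_mode=TableauHyperPublishMode.OVERWRITE,
)
publisher.execute()
print(publisher.output.datasource_item.id)
```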