Add database support for data transformation tool #119

Merged
merged 11 commits into from
Jun 26, 2024
8 changes: 7 additions & 1 deletion ods_tools/main.py
@@ -18,7 +18,10 @@
    ModelSettingSchema,
    AnalysisSettingSchema,
)
from ods_tools.odtf.controller import transform_format
try:
    from ods_tools.odtf.controller import transform_format
except ImportError:
    logger.info("Data transformation package requirements not installed.")


def get_oed_exposure(config_json=None, oed_dir=None, **kwargs):
@@ -108,6 +111,9 @@ def transform(**kwargs):
    except OdsException as e:
        logger.error("Transformation failed:")
        logger.error(e)
    except NameError as e:
        logger.error("Data transformation package requirements not installed.")
        logger.error(e)


command_action = {
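The guarded import and the new `except NameError` branch work together: if the optional transformation requirements are missing, the import fails quietly, and any later call to `transform_format` raises `NameError`, which `transform()` now reports instead of crashing. A minimal standalone sketch of the pattern, using a hypothetical package name rather than the real ods_tools layout:

import logging

logger = logging.getLogger(__name__)

try:
    # stand-in for: from ods_tools.odtf.controller import transform_format
    from optional_transform_package import transform_format  # hypothetical package
except ImportError:
    logger.info("Data transformation package requirements not installed.")


def run_transform(**kwargs):
    try:
        transform_format(**kwargs)  # NameError if the guarded import failed
    except NameError as e:
        logger.error("Data transformation package requirements not installed.")
        logger.error(e)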
14 changes: 11 additions & 3 deletions ods_tools/odtf/connector/csv.py
@@ -1,6 +1,8 @@
import csv
from typing import Any, Dict, Iterable

import pandas as pd

from .base import BaseConnector
from ..notset import NotSetType

@@ -95,6 +97,12 @@ def load(self, data: Iterable[Dict[str, Any]]):
            writer.writerow(self._data_serializer(first_row))
            writer.writerows(map(self._data_serializer, data))

    def extract(self) -> Iterable[Dict[str, Any]]:
        with open(self.file_path, "r") as f:
            yield from csv.DictReader(f, quoting=self.quoting)

    def fetch_data(self, chunksize: int) -> Iterable[pd.DataFrame]:
        """
        Fetch data from the csv file in batches.

        :param chunksize: Number of rows per batch
        :return: Iterable of data batches as pandas DataFrames
        """
        for batch in pd.read_csv(self.file_path, chunksize=chunksize, low_memory=False):
            yield batch
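`fetch_data` leans on pandas' chunked CSV reading: passing `chunksize` to `pd.read_csv` returns an iterator of DataFrames instead of one large frame, so arbitrarily large files can be streamed. A small sketch of the same pattern outside the connector (file name and chunk size are placeholders):

import pandas as pd


def read_in_batches(file_path: str, chunksize: int):
    # pd.read_csv with `chunksize` yields DataFrames of at most `chunksize` rows
    for batch in pd.read_csv(file_path, chunksize=chunksize, low_memory=False):
        yield batch


total = 0
for batch in read_in_batches("locations.csv", chunksize=10_000):  # placeholder file
    total += len(batch)
print(f"processed {total} rows")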
10 changes: 10 additions & 0 deletions ods_tools/odtf/connector/db/__init__.py
@@ -0,0 +1,10 @@
from .mssql import SQLServerConnector
from .postgres import PostgresConnector
from .sqlite import SQLiteConnector


__all__ = [
"SQLiteConnector",
"PostgresConnector",
"SQLServerConnector",
]
169 changes: 169 additions & 0 deletions ods_tools/odtf/connector/db/base.py
@@ -0,0 +1,169 @@
from typing import Any, Dict, Iterable, List

import sqlparams
import sqlparse

from ods_tools.odtf.connector import BaseConnector

from .errors import DBQueryError


class BaseDBConnector(BaseConnector):
    """
    Connects to a database for reading and writing data.

    **Options:**

    * `host` - Which host to use when connecting to the database
    * `port` - The port to use when connecting to the database
    * `database` - The database name or relative path to the file for sqlite3
    * `user` - The username to use when connecting to the database
    * `password` - The password to use when connecting to the database
    * `sql_statement` - Path to the file containing the SQL statement(s) to run
    """

    name = "BaseDB Connector"
    options_schema = {
        "type": "object",
        "properties": {
            "host": {
                "type": "string",
                "description": (
                    "Which host to use when connecting to the database. "
                    "Not used with SQLite."
                ),
                "default": "",
                "title": "Host",
            },
            "port": {
                "type": "string",
                "description": (
                    "The port to use when connecting to the database. "
                    "Not used with SQLite."
                ),
                "default": "",
                "title": "Port",
            },
            "database": {
                "type": "string",
                "description": (
                    "The database name or relative path to the file for "
                    "sqlite3"
                ),
                "title": "Database",
            },
            "user": {
                "type": "string",
                "description": (
                    "The username to use when connecting to the database. "
                    "Not used with SQLite."
                ),
                "default": "",
                "title": "User",
            },
            "password": {
                "type": "password",
                "description": (
                    "The password to use when connecting to the database. "
                    "Not used with SQLite."
                ),
                "default": "",
                "title": "Password",
            },
            "sql_statement": {
                "type": "string",
                "description": (
                    "The path to the file which contains the "
                    "sql statement to run"
                ),
                "subtype": "path",
                "title": "SQL Statement File",
            },
        },
        "required": ["database", "sql_statement"],
    }
    sql_params_output = "qmark"

    def __init__(self, config, **options):
        super().__init__(config, **options)

        self.database = {
            "host": options.get("host", ""),
            "port": options.get("port", ""),
            "database": options["database"],
            "user": options.get("user", ""),
            "password": options.get("password", ""),
        }
        self.sql_statement_path = config.absolute_path(
            options["sql_statement"]
        )

    def _create_connection(self, database: Dict[str, str]):
        raise NotImplementedError()

    def _get_cursor(self, conn):
        cur = conn.cursor()
        return cur

    def _get_select_statement(self) -> str:
        """
        SQL string to select the data from the DB

        :return: string
        """
        with open(self.sql_statement_path) as f:
            select_statement = f.read()

        return select_statement

    def _get_insert_statements(self) -> List[str]:
        """
        SQL string(s) to insert the data into the DB

        :return: List of sql statements
        """
        with open(self.sql_statement_path) as f:
            sql = f.read()

        return sqlparse.split(sql)

    def load(self, data: Iterable[Dict[str, Any]]):
        insert_sql = self._get_insert_statements()
        data = list(
            data
        )  # convert iterable to list as we reuse it based on number of queries
        conn = self._create_connection(self.database)

        with conn:
            cur = self._get_cursor(conn)
            query = sqlparams.SQLParams("named", self.sql_params_output)

            # insert query can contain more than 1 statement
            for line in insert_sql:
                sql, params = query.formatmany(line, data)
                try:
                    cur.executemany(sql, params)
                except Exception as e:
                    raise DBQueryError(sql, e, data=data)

    def row_to_dict(self, row):
        """
        Convert the row returned from the cursor into a dictionary

        :return: Dict
        """
        return dict(row)

    def extract(self) -> Iterable[Dict[str, Any]]:
        select_sql = self._get_select_statement()
        conn = self._create_connection(self.database)

        with conn:
            cur = self._get_cursor(conn)
            try:
                cur.execute(select_sql)
            except Exception as e:
                raise DBQueryError(select_sql, e)

            rows = cur.fetchall()
            for row in rows:
                yield self.row_to_dict(row)
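The least obvious step in `load()` is the `sqlparams` conversion: the statement file is written with named-style placeholders, and `SQLParams("named", <output style>).formatmany()` rewrites the SQL for the target driver and turns the row dicts into positional parameter rows for `executemany`. A sketch with an assumed insert statement (not taken from the repository):

import sqlparams

query = sqlparams.SQLParams("named", "qmark")
insert = "INSERT INTO location (PortNumber, LocNumber) VALUES (:PortNumber, :LocNumber)"
rows = [
    {"PortNumber": 1, "LocNumber": "L1"},
    {"PortNumber": 1, "LocNumber": "L2"},
]
sql, params = query.formatmany(insert, rows)
# sql now uses qmark placeholders: INSERT ... VALUES (?, ?)
# params is a sequence of positional rows, roughly [(1, "L1"), (1, "L2")]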
18 changes: 18 additions & 0 deletions ods_tools/odtf/connector/db/errors.py
@@ -0,0 +1,18 @@
from ...errors import ConverterError


class DBConnectionError(ConverterError):
    pass


class DBQueryError(ConverterError):
    def __init__(self, query, error, data=None):
        self.query = query
        self.data = data
        self.error = error

        super().__init__(f"Error running query: {query} with {data} - {error}")


class DBInsertDataError(ConverterError):
    pass
56 changes: 56 additions & 0 deletions ods_tools/odtf/connector/db/mssql.py
@@ -0,0 +1,56 @@
from typing import Dict

import pandas as pd
import pyodbc

from .base import BaseDBConnector
from .errors import DBConnectionError


class SQLServerConnector(BaseDBConnector):
    """
    Connects to a Microsoft SQL Server database for reading and writing data.
    """

    name = "SQL Server Connector"
    driver = "{ODBC Driver 17 for SQL Server}"

    def _create_connection(self, database: Dict[str, str]):
        """
        Create database connection to the SQL Server database specified in database
        :param database: Dict object with connection info

        :return: Connection object
        """

        try:
            conn = pyodbc.connect(
                "DRIVER={};SERVER={};PORT={};DATABASE={};UID={};PWD={}".format(
                    self.driver,
                    database["host"],
                    database["port"],
                    database["database"],
                    database["user"],
                    database["password"],
                )
            )
        except Exception:
            raise DBConnectionError()

        return conn

    def fetch_data(self, batch_size: int):
        """
        Fetch data from the database in batches.

        :param batch_size: Number of rows per batch

        :yield: Data batches as pandas DataFrames
        """

        with open(self.sql_statement_path, 'r') as file:
            sql_query = file.read()

        with self._create_connection(self.database) as conn:
            for batch in pd.read_sql(sql_query, conn, chunksize=batch_size):
                yield batch
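For reference, the connection details used above come from the connector options defined in `base.py`. A hypothetical options block for this connector (every value is a placeholder) might look like:

options = {
    "host": "localhost",
    "port": "1433",  # default SQL Server port
    "database": "exposure_db",  # placeholder database name
    "user": "ods_user",
    "password": "change-me",
    "sql_statement": "./transformation.sql",  # path to the SQL statement file
}
# BaseDBConnector.__init__ copies host/port/database/user/password into
# self.database and resolves sql_statement via config.absolute_path().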
47 changes: 47 additions & 0 deletions ods_tools/odtf/connector/db/postgres.py
@@ -0,0 +1,47 @@
from typing import Dict

import pandas as pd
import psycopg2
import psycopg2.extras

from .base import BaseDBConnector
from .errors import DBConnectionError


class PostgresConnector(BaseDBConnector):
    """
    Connects to a Postgres database for reading and writing data.
    """

    name = "Postgres Connector"
    sql_params_output = "pyformat"

    def _create_connection(self, database: Dict[str, str]):
        """
        Create database connection to the Postgres database
        :param database: Dict with database connection settings

        :return: Connection object
        """
        try:
            conn = psycopg2.connect(**database)
        except Exception as e:
            raise DBConnectionError(e)

        return conn

    def fetch_data(self, batch_size: int):
        """
        Fetch data from the database in batches.

        :param batch_size: Number of rows per batch

        :yield: Data batches as pandas DataFrames
        """

        with open(self.sql_statement_path, 'r') as file:
            sql_query = file.read()

        with self._create_connection(self.database) as conn:
            for batch in pd.read_sql(sql_query, conn, chunksize=batch_size):
                yield batch
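Both database connectors stream results the same way: `pd.read_sql` with `chunksize` returns an iterator of DataFrames, so rows are pulled from the open connection batch by batch rather than fetched all at once. A minimal sketch of that pattern against the stdlib SQLite driver (table and file names are made up):

import sqlite3

import pandas as pd

conn = sqlite3.connect("example.db")  # placeholder database file
try:
    for batch in pd.read_sql("SELECT * FROM location", conn, chunksize=500):
        print(len(batch))  # each batch is a DataFrame of at most 500 rows
finally:
    conn.close()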