Skip to content

Commit

Permalink
Add store_tmpfile option to job_result_format_each() method to JobAPI…
Browse files Browse the repository at this point in the history
… and Client (#120)

* Add store_tmpfile option to job_result_format_each() method to JobAPI and Client
  • Loading branch information
chezou authored Sep 6, 2024
1 parent b6da038 commit 649d69f
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 66 deletions.
4 changes: 2 additions & 2 deletions tdclient/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
retry_post_requests=False,
max_cumul_retry_delay=600,
http_proxy=None,
**kwargs
**kwargs,
):
headers = {} if headers is None else headers
if apikey is not None:
Expand Down Expand Up @@ -600,7 +600,7 @@ def _read_csv_file(
encoding="utf-8",
dtypes=None,
converters=None,
**kwargs
**kwargs,
):
if columns is None:
reader = csv_dict_record_reader(file_like, encoding, dialect)
Expand Down
30 changes: 21 additions & 9 deletions tdclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@


class Client:
"""API Client for Treasure Data Service
"""
"""API Client for Treasure Data Service"""

def __init__(self, *args, **kwargs):
    """Create a client; all arguments are forwarded to :class:`api.API`."""
    self._api = api.API(*args, **kwargs)
Expand Down Expand Up @@ -79,7 +78,7 @@ def database(self, db_name):
:class:`tdclient.models.Database`
"""
databases = self.api.list_databases()
for (name, kwargs) in databases.items():
for name, kwargs in databases.items():
if name == db_name:
return models.Database(self, name, **kwargs)
raise api.NotFoundError("Database '%s' does not exist" % (db_name))
Expand Down Expand Up @@ -229,7 +228,7 @@ def query(
priority=None,
retry_limit=None,
type="hive",
**kwargs
**kwargs,
):
"""Run a query on specified database table.
Expand Down Expand Up @@ -258,7 +257,7 @@ def query(
result_url=result_url,
priority=priority,
retry_limit=retry_limit,
**kwargs
**kwargs,
)
return models.Job(self, job_id, type, q)

Expand Down Expand Up @@ -334,16 +333,30 @@ def job_result_format(self, job_id, format, header=False):
"""
return self.api.job_result_format(job_id, format, header=header)

def job_result_format_each(
    self, job_id, format, header=False, store_tmpfile=False, num_threads=4
):
    """Yield each row of a job result in the requested format.

    Thin delegation to :meth:`api.API.job_result_format_each`.

    Args:
        job_id (str): job id
        format (str): output format of result set
        header (bool, optional): include header in the result set.
            Default: False
        store_tmpfile (bool, optional): store result to a temporary file
            before iterating. Works only when ``format`` is "msgpack".
            Default: False
        num_threads (int, optional): number of threads to download the
            result. Works only when ``store_tmpfile`` is True. Default: 4
    Returns:
        an iterator of rows in the result set
    """
    # Generator delegation keeps the lazy, row-at-a-time contract of the API.
    yield from self.api.job_result_format_each(
        job_id,
        format,
        header=header,
        store_tmpfile=store_tmpfile,
        num_threads=num_threads,
    )

def download_job_result(self, job_id, path, num_threads=4):
Expand Down Expand Up @@ -940,8 +953,7 @@ def remove_apikey(self, name, apikey):
return self.api.remove_apikey(name, apikey)

def close(self):
    """Close opened API connections.

    Returns:
        whatever the underlying :meth:`api.API.close` returns
    """
    return self._api.close()


Expand Down
46 changes: 38 additions & 8 deletions tdclient/job_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python

import codecs
import gzip
import json
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import msgpack
Expand Down Expand Up @@ -214,7 +216,7 @@ def job_result_format(self, job_id, format, header=False):
job_id (int): Job ID
format (str): Output format of the job result information.
"json" or "msgpack"
header (boolean): Includes Header or not.
header (boolean): Includes Header or not.
False or True
Returns:
Expand All @@ -225,15 +227,22 @@ def job_result_format(self, job_id, format, header=False):
result.append(row)
return result

def job_result_format_each(self, job_id, format, header=False):
def job_result_format_each(
self, job_id, format, header=False, store_tmpfile=False, num_threads=4
):
"""Yield a row of the job result with specified format.
Args:
job_id (int): job ID
format (str): Output format of the job result information.
format (str): Output format of the job result information.
"json" or "msgpack"
header (bool): Include Header info or not
"True" or "False"
store_tmpfile (bool): Download job result as a temporary file or not. Default is False.
It works only when format is "msgpack".
"True" or "False"
num_threads (int): Number of threads to download the job result when store_tmpfile is True.
Default is 4.
Yields:
The query result of the specified job in.
"""
Expand All @@ -246,16 +255,37 @@ def job_result_format_each(self, job_id, format, header=False):
if format != "msgpack":
format = "json"

if store_tmpfile:
if format != "msgpack":
raise ValueError("store_tmpfile works only when format is msgpack")

with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, f"{job_id}.msgpack.gz")
self.download_job_result(job_id, path, num_threads)
with gzip.GzipFile(path, "rb") as f:
unpacker = msgpack.Unpacker(
f, raw=False, max_buffer_size=1000 * 1024**2
)
for row in unpacker:
yield row
return

with self.get(
create_url("/v3/job/result/{job_id}?format={format}&header={header}",
job_id=job_id, format=format, header=header)
create_url(
"/v3/job/result/{job_id}?format={format}&header={header}",
job_id=job_id,
format=format,
header=header,
)
) as res:
code = res.status
if code != 200:
self.raise_error("Get job result failed", res, "")
if format == "msgpack":
unpacker = msgpack.Unpacker(raw=False, max_buffer_size=1000 * 1024 ** 2)
for chunk in res.stream(1024 ** 2):
unpacker = msgpack.Unpacker(
raw=False, max_buffer_size=1000 * 1024**2
)
for chunk in res.stream(1024**2):
unpacker.feed(chunk)
for row in unpacker:
yield row
Expand Down Expand Up @@ -354,7 +384,7 @@ def query(
result_url=None,
priority=None,
retry_limit=None,
**kwargs
**kwargs,
):
"""Create a job for given query.
Expand Down
76 changes: 32 additions & 44 deletions tdclient/job_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@


class Schema:
"""Schema of a database table on Treasure Data Service
"""
"""Schema of a database table on Treasure Data Service"""

class Field:
def __init__(self, name, type):
Expand Down Expand Up @@ -48,8 +47,7 @@ def add_field(self, name, type):


class Job(Model):
"""Job on Treasure Data Service
"""
"""Job on Treasure Data Service"""

STATUS_QUEUED = "queued"
STATUS_BOOTING = "booting"
Expand Down Expand Up @@ -92,8 +90,7 @@ def _feed(self, data=None):
self._result_export_target_job_id = data.get("result_export_target_job_id")

def update(self):
    """Re-fetch the job description from the server and refresh all cached fields."""
    data = self._client.api.show_job(self._job_id)
    self._feed(data)

Expand All @@ -105,57 +102,48 @@ def _update_status(self):
self.update()

def _update_progress(self):
    """Refresh the cached ``_status`` field unless the job has already finished."""
    # FINISHED_STATUS is a class-level constant; once finished, status is final.
    if self._status not in self.FINISHED_STATUS:
        self._status = self._client.job_status(self._job_id)

@property
def id(self):
    """The job identifier as a string."""
    return self._job_id

@property
def job_id(self):
    """The job identifier as a string (same value as :attr:`id`)."""
    return self._job_id

@property
def type(self):
    """The engine type of the job, e.g. ``"hive"`` or ``"presto"``."""
    return self._type

@property
def result_size(self):
    """The length of the job result."""
    return self._result_size

@property
def num_records(self):
    """The number of records in the job result."""
    return self._num_records

@property
def result_url(self):
    """URL of the job result on Treasure Data Service."""
    return self._result_url

@property
def result_schema(self):
    """Types of result columns (Hive specific), e.g. ``[["_c1", "string"], ["_c2", "bigint"]]``."""
    return self._hive_result_schema

@property
def priority(self):
"""a string represents the priority of the job (e.g. "NORMAL", "HIGH", etc.)
"""
"""a string represents the priority of the job (e.g. "NORMAL", "HIGH", etc.)"""
if self._priority in self.JOB_PRIORITY:
return self.JOB_PRIORITY[self._priority]
else:
Expand All @@ -164,44 +152,37 @@ def priority(self):

@property
def retry_limit(self):
    """A number for the automatic retry count."""
    return self._retry_limit

@property
def org_name(self):
    """The organization name."""
    return self._org_name

@property
def user_name(self):
    """Name of the user executing the job."""
    return self._user_name

@property
def database(self):
    """Name of the database the job is running on."""
    return self._database

@property
def linked_result_export_job_id(self):
    """Linked result-export job ID from a query job."""
    return self._linked_result_export_job_id

@property
def result_export_target_job_id(self):
    """Associated query job ID from a result-export job."""
    return self._result_export_target_job_id

@property
def debug(self):
    """A :class:`dict` of debug output (e.g. ``"cmdout"``, ``"stderr"``)."""
    return self._debug

def wait(self, timeout=None, wait_interval=5, wait_callback=None):
Expand Down Expand Up @@ -235,8 +216,7 @@ def kill(self):

@property
def query(self):
    """The query string of the job."""
    return self._query

def status(self):
Expand All @@ -250,8 +230,7 @@ def status(self):

@property
def url(self):
    """URL of the job on Treasure Data Service."""
    return self._url

def result(self):
Expand All @@ -270,10 +249,14 @@ def result(self):
for row in self._result:
yield row

def result_format(self, fmt):
def result_format(self, fmt, store_tmpfile=False, num_threads=4):
"""
Args:
fmt (str): output format of result set
store_tmpfile (bool, optional): store result to a temporary file.
Works only when fmt is "msgpack". Default is False.
num_threads (int, optional): number of threads to download result.
Works only when store_tmpfile is True. Default is 4.
Yields:
an iterator of rows in result set
Expand All @@ -283,7 +266,12 @@ def result_format(self, fmt):
else:
self.update()
if self._result is None:
for row in self._client.job_result_format_each(self._job_id, fmt):
for row in self._client.job_result_format_each(
self._job_id,
fmt,
store_tmpfile=store_tmpfile,
num_threads=num_threads,
):
yield row
else:
for row in self._result:
Expand Down
Loading

0 comments on commit 649d69f

Please sign in to comment.