Skip to content

Commit

Permalink
feat: per-process ingest connections (#1058)
Browse files Browse the repository at this point in the history
* adds per process connections for Google Drive connector
  • Loading branch information
ryannikolaidis authored Aug 17, 2023
1 parent dd0f582 commit 668d0f1
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 19 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
## 0.10.2
## 0.10.3

### Enhancements
* Bump unstructured-inference==0.5.13:
- Fix extracted image elements being included in layout merge, addresses the issue
where an entire-page image in a PDF was not passed to the layout model when using hi_res.
* Adds ability to reuse connections per process in unstructured-ingest

### Features

Expand Down
39 changes: 39 additions & 0 deletions test_unstructured_ingest/unit/doc_processor/test_generalized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass

import pytest

from unstructured.ingest.doc_processor.generalized import (
process_document,
)
from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin


@dataclass
class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc):
    """Test-only ingest doc type that combines the session-handle mixin with BaseIngestDoc.

    Used as a ``spec`` for mocks so that ``isinstance(doc, IngestDocSessionHandleMixin)``
    holds inside ``process_document``.
    """

    pass

def test_process_document_with_session_handle(mocker):
    """process_document should drive the doc through get_file / process_file /
    write_result / cleanup_file, hand the already-cached module-level session handle
    to the doc, and return whatever process_file produced."""
    shared_handle = mocker.MagicMock()
    mocker.patch(
        "unstructured.ingest.doc_processor.generalized.session_handle",
        shared_handle,
    )
    doc = mocker.MagicMock(spec=IngestDocWithSessionHandle)

    outcome = process_document(doc)

    # The cached handle must have been assigned onto the doc ...
    assert doc.session_handle == shared_handle
    # ... and the result of partitioning must be passed straight through.
    assert outcome == doc.process_file.return_value
    doc.get_file.assert_called_once_with()
    doc.write_result.assert_called_with()
    doc.cleanup_file.assert_called_once_with()


def test_process_document_no_session_handle(mocker):
    """process_document must not assign a session handle when the IngestDoc does not
    use the session handle mixin."""
    unused_handle = mocker.MagicMock()
    mocker.patch(
        "unstructured.ingest.doc_processor.generalized.session_handle",
        unused_handle,
    )
    doc = mocker.MagicMock(spec=BaseIngestDoc)

    process_document(doc)

    # A spec'd mock only grows attributes that were explicitly set on it, so the
    # absence of the attribute shows that no assignment took place.
    assert not hasattr(doc, "session_handle")
5 changes: 0 additions & 5 deletions test_unstructured_ingest/unit/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def test_partition_file():
assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED


@freeze_time(TEST_DATE_PROCESSSED)
def test_process_file_fields_include_default(mocker, partition_test_results):
"""Validate when metadata_include and metadata_exclude are not set, all fields:
("element_id", "text", "type", "metadata") are included"""
Expand All @@ -162,10 +161,6 @@ def test_process_file_fields_include_default(mocker, partition_test_results):
isd_elems = test_ingest_doc.process_file()
assert len(isd_elems)
assert mock_partition.call_count == 1
assert (
mock_partition.call_args.kwargs["data_source_metadata"].date_processed
== TEST_DATE_PROCESSSED
)
for elem in isd_elems:
assert {"element_id", "text", "type", "metadata"} == set(elem.keys())
data_source_metadata = elem["metadata"]["data_source"]
Expand Down
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.2" # pragma: no cover
__version__ = "0.10.3" # pragma: no cover
35 changes: 24 additions & 11 deletions unstructured/ingest/connector/google_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,36 @@
from dataclasses import dataclass
from mimetypes import guess_extension
from pathlib import Path
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional

from unstructured.file_utils.filetype import EXT_TO_FILETYPE
from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES
from unstructured.ingest.interfaces import (
BaseConnector,
BaseConnectorConfig,
BaseIngestDoc,
BaseSessionHandle,
ConfigSessionHandleMixin,
ConnectorCleanupMixin,
IngestDocCleanupMixin,
IngestDocSessionHandleMixin,
StandardConnectorConfig,
)
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
from googleapiclient.discovery import Resource as GoogleAPIResource

FILE_FORMAT = "{id}-{name}{ext}"
DIRECTORY_FORMAT = "{id}-{name}"


@dataclass
class GoogleDriveSessionHandle(BaseSessionHandle):
    """Session handle that carries a Google Drive API service object."""

    # The annotation is a string because GoogleAPIResource is only imported under
    # TYPE_CHECKING in this module.
    service: "GoogleAPIResource"


@requires_dependencies(["googleapiclient"], extras="google-drive")
def create_service_account_object(key_path, id=None):
"""
Expand Down Expand Up @@ -65,7 +76,7 @@ def create_service_account_object(key_path, id=None):


@dataclass
class SimpleGoogleDriveConfig(BaseConnectorConfig):
class SimpleGoogleDriveConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
"""Connector config where drive_id is the id of the document to process or
the folder to process all documents from."""

Expand All @@ -81,11 +92,16 @@ def __post_init__(self):
f"Extension not supported. "
f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.",
)
self.service = create_service_account_object(self.service_account_key, self.drive_id)

def create_session_handle(self) -> GoogleDriveSessionHandle:
    """Build a new Google Drive session handle from this config.

    A service object is created from ``service_account_key`` on every call and
    wrapped in a ``GoogleDriveSessionHandle``; nothing is cached on the config.
    """
    return GoogleDriveSessionHandle(
        service=create_service_account_object(self.service_account_key),
    )


@dataclass
class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc):
config: SimpleGoogleDriveConfig
file_meta: Dict

Expand All @@ -103,8 +119,6 @@ def get_file(self):
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload

self.config.service = create_service_account_object(self.config.service_account_key)

if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"):
export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get(
self.file_meta.get("mimeType"), # type: ignore
Expand All @@ -117,12 +131,12 @@ def get_file(self):
)
return

request = self.config.service.files().export_media(
request = self.session_handle.service.files().export_media(
fileId=self.file_meta.get("id"),
mimeType=export_mime,
)
else:
request = self.config.service.files().get_media(fileId=self.file_meta.get("id"))
request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id"))
file = io.BytesIO()
downloader = MediaIoBaseDownload(file, request)
downloaded = False
Expand Down Expand Up @@ -170,12 +184,13 @@ def __init__(self, standard_config: StandardConnectorConfig, config: SimpleGoogl

def _list_objects(self, drive_id, recursive=False):
files = []
service = self.config.create_session_handle().service

def traverse(drive_id, download_dir, output_dir, recursive=False):
page_token = None
while True:
response = (
self.config.service.files()
service.files()
.list(
spaces="drive",
fields="nextPageToken, files(id, name, mimeType)",
Expand Down Expand Up @@ -244,6 +259,4 @@ def initialize(self):

def get_ingest_docs(self):
    """List the Drive objects under ``config.drive_id`` and wrap each one in an ingest doc.

    Returns:
        A list with one ``GoogleDriveIngestDoc`` per entry returned by ``_list_objects``.
    """
    files = self._list_objects(self.config.drive_id, self.config.recursive)
    # Setting to None because service object can't be pickled for multiprocessing.
    self.config.service = None
    return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files]
17 changes: 16 additions & 1 deletion unstructured/ingest/doc_processor/generalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from unstructured_inference.models.base import get_model

from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc
from unstructured.ingest.interfaces import (
BaseSessionHandle,
IngestDocSessionHandleMixin,
)
from unstructured.ingest.logger import logger

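# Module-level (i.e. one per process) session handle; set on first use in
# process_document() and then reused for subsequent documents in the same process.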
session_handle: Optional[BaseSessionHandle] = None


def initialize():
"""Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment
Expand All @@ -30,16 +37,24 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
partition_kwargs
ultimately the parameters passed to partition()
"""
global session_handle
isd_elems_no_filename = None
try:
if isinstance(doc, IngestDocSessionHandleMixin):
if session_handle is None:
# create via doc.session_handle, which is a property that creates a
# session handle if one is not already defined
session_handle = doc.session_handle
else:
doc.session_handle = session_handle
# does the work necessary to load file into filesystem
# in the future, get_file_handle() could also be supported
doc.get_file()

isd_elems_no_filename = doc.process_file(**partition_kwargs)

# Note, this may be a no-op if the IngestDoc doesn't do anything to persist
# the results. Instead, the MainProcess (caller) may work with the aggregate
# the results. Instead, the Processor (caller) may work with the aggregate
# results across all docs in memory.
doc.write_result()
except Exception:
Expand Down
29 changes: 29 additions & 0 deletions unstructured/ingest/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from unstructured.staging.base import convert_to_dict


@dataclass
class BaseSessionHandle(ABC):
    """Abstract base class for resources that are local to an individual process and
    shared within it, e.g., a connection used for making requests to fetch documents."""


@dataclass
class ProcessorConfigs:
"""Common set of config required when running data connectors."""
Expand Down Expand Up @@ -330,3 +336,26 @@ def cleanup_file(self):
):
logger.debug(f"Cleaning up {self}")
os.unlink(self.filename)


class ConfigSessionHandleMixin(ABC):
    """Mixin for connector configs that know how to create a session handle.

    The class derives from ``ABC`` so that ``@abstractmethod`` is actually enforced.
    On a plain class the decorator is inert, and subclasses do not inherit the
    abstractness even when they get ``ABCMeta`` from another base, so a config that
    forgot to implement ``create_session_handle`` would silently return ``None``.
    """

    @abstractmethod
    def create_session_handle(self) -> "BaseSessionHandle":
        """Creates a session handle that will be assigned on each IngestDoc to share
        session related resources across all document handling for a given subprocess.

        Returns:
            A new ``BaseSessionHandle`` (subclass) instance.

        Raises:
            TypeError: at instantiation time of a subclass that does not override
                this method.
        """


class IngestDocSessionHandleMixin:
    """Mixin giving an IngestDoc a ``session_handle`` that is created on first access
    from the doc's config, or assigned from outside."""

    config: ConfigSessionHandleMixin
    _session_handle: Optional[BaseSessionHandle] = None

    @property
    def session_handle(self):
        """Return the assigned session handle, creating it from the config when none
        has been assigned yet."""
        if self._session_handle is not None:
            return self._session_handle
        self._session_handle = self.config.create_session_handle()
        return self._session_handle

    @session_handle.setter
    def session_handle(self, session_handle: BaseSessionHandle):
        """Assign an existing session handle to this doc."""
        self._session_handle = session_handle

0 comments on commit 668d0f1

Please sign in to comment.