From 668d0f1b01d28cc5cf7e30381ce12a0114a221ae Mon Sep 17 00:00:00 2001 From: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com> Date: Thu, 17 Aug 2023 10:34:08 -0700 Subject: [PATCH] feat: per-process ingest connections (#1058) * adds per process connections for Google Drive connector --- CHANGELOG.md | 3 +- .../unit/doc_processor/test_generalized.py | 39 +++++++++++++++++++ .../unit/test_interfaces.py | 5 --- unstructured/__version__.py | 2 +- unstructured/ingest/connector/google_drive.py | 35 +++++++++++------ .../ingest/doc_processor/generalized.py | 17 +++++++- unstructured/ingest/interfaces.py | 29 ++++++++++++++ 7 files changed, 111 insertions(+), 19 deletions(-) create mode 100644 test_unstructured_ingest/unit/doc_processor/test_generalized.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 105b68f52a..16f15457fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ -## 0.10.2 +## 0.10.3 ### Enhancements * Bump unstructured-inference==0.5.13: - Fix extracted image elements being included in layout merge, addresses the issue where an entire-page image in a PDF was not passed to the layout model when using hi_res. 
+* Adds ability to reuse connections per process in unstructured-ingest ### Features diff --git a/test_unstructured_ingest/unit/doc_processor/test_generalized.py b/test_unstructured_ingest/unit/doc_processor/test_generalized.py new file mode 100644 index 0000000000..1343d526a9 --- /dev/null +++ b/test_unstructured_ingest/unit/doc_processor/test_generalized.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +import pytest + +from unstructured.ingest.doc_processor.generalized import ( + process_document, +) +from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin + + +@dataclass +class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc): + pass + +def test_process_document_with_session_handle(mocker): + """Test that the process_document function calls the doc_processor_fn with the correct + arguments, assigns the session handle, and returns the correct results.""" + mock_session_handle = mocker.MagicMock() + mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mock_session_handle) + mock_doc = mocker.MagicMock(spec=(IngestDocWithSessionHandle)) + + result = process_document(mock_doc) + + mock_doc.get_file.assert_called_once_with() + mock_doc.write_result.assert_called_with() + mock_doc.cleanup_file.assert_called_once_with() + assert result == mock_doc.process_file.return_value + assert mock_doc.session_handle == mock_session_handle + + +def test_process_document_no_session_handle(mocker): + """Test that the process_document function does not assign a session handle when the IngestDoc + does not have the session handle mixin.""" + mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mocker.MagicMock()) + mock_doc = mocker.MagicMock(spec=(BaseIngestDoc)) + + process_document(mock_doc) + + assert not hasattr(mock_doc, "session_handle") diff --git a/test_unstructured_ingest/unit/test_interfaces.py b/test_unstructured_ingest/unit/test_interfaces.py index 
2dacd4161a..c00f2cb348 100644 --- a/test_unstructured_ingest/unit/test_interfaces.py +++ b/test_unstructured_ingest/unit/test_interfaces.py @@ -144,7 +144,6 @@ def test_partition_file(): assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED -@freeze_time(TEST_DATE_PROCESSSED) def test_process_file_fields_include_default(mocker, partition_test_results): """Validate when metadata_include and metadata_exclude are not set, all fields: ("element_id", "text", "type", "metadata") are included""" @@ -162,10 +161,6 @@ def test_process_file_fields_include_default(mocker, partition_test_results): isd_elems = test_ingest_doc.process_file() assert len(isd_elems) assert mock_partition.call_count == 1 - assert ( - mock_partition.call_args.kwargs["data_source_metadata"].date_processed - == TEST_DATE_PROCESSSED - ) for elem in isd_elems: assert {"element_id", "text", "type", "metadata"} == set(elem.keys()) data_source_metadata = elem["metadata"]["data_source"] diff --git a/unstructured/__version__.py b/unstructured/__version__.py index 563e2bd6bd..c6adbb089b 100644 --- a/unstructured/__version__.py +++ b/unstructured/__version__.py @@ -1 +1 @@ -__version__ = "0.10.2" # pragma: no cover +__version__ = "0.10.3" # pragma: no cover diff --git a/unstructured/ingest/connector/google_drive.py b/unstructured/ingest/connector/google_drive.py index 02b5a07be3..7053a032eb 100644 --- a/unstructured/ingest/connector/google_drive.py +++ b/unstructured/ingest/connector/google_drive.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from mimetypes import guess_extension from pathlib import Path -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional from unstructured.file_utils.filetype import EXT_TO_FILETYPE from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES @@ -12,17 +12,28 @@ BaseConnector, BaseConnectorConfig, BaseIngestDoc, + BaseSessionHandle, + ConfigSessionHandleMixin, ConnectorCleanupMixin, IngestDocCleanupMixin, + 
IngestDocSessionHandleMixin, StandardConnectorConfig, ) from unstructured.ingest.logger import logger from unstructured.utils import requires_dependencies +if TYPE_CHECKING: + from googleapiclient.discovery import Resource as GoogleAPIResource + FILE_FORMAT = "{id}-{name}{ext}" DIRECTORY_FORMAT = "{id}-{name}" +@dataclass +class GoogleDriveSessionHandle(BaseSessionHandle): + service: "GoogleAPIResource" + + @requires_dependencies(["googleapiclient"], extras="google-drive") def create_service_account_object(key_path, id=None): """ @@ -65,7 +76,7 @@ def create_service_account_object(key_path, id=None): @dataclass -class SimpleGoogleDriveConfig(BaseConnectorConfig): +class SimpleGoogleDriveConfig(ConfigSessionHandleMixin, BaseConnectorConfig): """Connector config where drive_id is the id of the document to process or the folder to process all documents from.""" @@ -81,11 +92,16 @@ def __post_init__(self): f"Extension not supported. " f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.", ) - self.service = create_service_account_object(self.service_account_key, self.drive_id) + + def create_session_handle( + self, + ) -> GoogleDriveSessionHandle: + service = create_service_account_object(self.service_account_key) + return GoogleDriveSessionHandle(service=service) @dataclass -class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc): +class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc): config: SimpleGoogleDriveConfig file_meta: Dict @@ -103,8 +119,6 @@ def get_file(self): from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload - self.config.service = create_service_account_object(self.config.service_account_key) - if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"): export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get( self.file_meta.get("mimeType"), # type: ignore @@ -117,12 +131,12 @@ def get_file(self): ) return - request = 
self.config.service.files().export_media( + request = self.session_handle.service.files().export_media( fileId=self.file_meta.get("id"), mimeType=export_mime, ) else: - request = self.config.service.files().get_media(fileId=self.file_meta.get("id")) + request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id")) file = io.BytesIO() downloader = MediaIoBaseDownload(file, request) downloaded = False @@ -170,12 +184,13 @@ def __init__(self, standard_config: StandardConnectorConfig, config: SimpleGoogl def _list_objects(self, drive_id, recursive=False): files = [] + service = self.config.create_session_handle().service def traverse(drive_id, download_dir, output_dir, recursive=False): page_token = None while True: response = ( - self.config.service.files() + service.files() .list( spaces="drive", fields="nextPageToken, files(id, name, mimeType)", @@ -244,6 +259,4 @@ def initialize(self): def get_ingest_docs(self): files = self._list_objects(self.config.drive_id, self.config.recursive) - # Setting to None because service object can't be pickled for multiprocessing. 
- self.config.service = None return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files] diff --git a/unstructured/ingest/doc_processor/generalized.py b/unstructured/ingest/doc_processor/generalized.py index 5fcef9f8c6..243d465678 100644 --- a/unstructured/ingest/doc_processor/generalized.py +++ b/unstructured/ingest/doc_processor/generalized.py @@ -6,8 +6,15 @@ from unstructured_inference.models.base import get_model from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc +from unstructured.ingest.interfaces import ( + BaseSessionHandle, + IngestDocSessionHandleMixin, +) from unstructured.ingest.logger import logger +# module-level variable to store session handle +session_handle: Optional[BaseSessionHandle] = None + def initialize(): """Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment @@ -30,8 +37,16 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict partition_kwargs ultimately the parameters passed to partition() """ + global session_handle isd_elems_no_filename = None try: + if isinstance(doc, IngestDocSessionHandleMixin): + if session_handle is None: + # create via doc.session_handle, which is a property that creates a + # session handle if one is not already defined + session_handle = doc.session_handle + else: + doc.session_handle = session_handle # does the work necessary to load file into filesystem # in the future, get_file_handle() could also be supported doc.get_file() @@ -39,7 +54,7 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict isd_elems_no_filename = doc.process_file(**partition_kwargs) # Note, this may be a no-op if the IngestDoc doesn't do anything to persist - # the results. Instead, the MainProcess (caller) may work with the aggregate + # the results. Instead, the Processor (caller) may work with the aggregate # results across all docs in memory. 
doc.write_result() except Exception: diff --git a/unstructured/ingest/interfaces.py b/unstructured/ingest/interfaces.py index 2b1fd6bbcb..cc686abfb8 100644 --- a/unstructured/ingest/interfaces.py +++ b/unstructured/ingest/interfaces.py @@ -18,6 +18,12 @@ from unstructured.staging.base import convert_to_dict +@dataclass +class BaseSessionHandle(ABC): + """Abstract Base Class for sharing resources that are local to an individual process, + e.g., a connection for making a request for fetching documents.""" + + @dataclass class ProcessorConfigs: """Common set of config required when running data connectors.""" @@ -330,3 +336,26 @@ def cleanup_file(self): ): logger.debug(f"Cleaning up {self}") os.unlink(self.filename) + + +class ConfigSessionHandleMixin: + @abstractmethod + def create_session_handle(self) -> BaseSessionHandle: + """Creates a session handle that will be assigned to each IngestDoc to share + session-related resources across all document handling for a given subprocess.""" + + +class IngestDocSessionHandleMixin: + config: ConfigSessionHandleMixin + _session_handle: Optional[BaseSessionHandle] = None + + @property + def session_handle(self): + """If a session handle is not assigned, creates a new one and assigns it.""" + if self._session_handle is None: + self._session_handle = self.config.create_session_handle() + return self._session_handle + + @session_handle.setter + def session_handle(self, session_handle: BaseSessionHandle): + self._session_handle = session_handle