From 1ceee396bbd580c1a009ea2dc0afff0dd3a42548 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Fri, 27 Sep 2024 09:19:47 -0600 Subject: [PATCH] fix --- llama_extract/base.py | 12 ++++++++---- pyproject.toml | 7 +++---- tests/test_extract.py | 18 +++++++++++++++++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/llama_extract/base.py b/llama_extract/base.py index 036b4b1..e2dea96 100644 --- a/llama_extract/base.py +++ b/llama_extract/base.py @@ -1,10 +1,13 @@ import asyncio import os import time + +import pydantic.v1 as pydantic_v1 + from io import BufferedIOBase, BufferedReader, BytesIO from json.decoder import JSONDecodeError from pathlib import Path -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, Extra, ValidationError from typing import List, Optional, Tuple, Type, Union import urllib.parse @@ -212,15 +215,16 @@ async def ainfer_schema( ) if 200 <= _response.status_code < 300: - return pydantic.parse_obj_as(ExtractionSchema, _response.json()) # type: ignore + return pydantic_v1.parse_obj_as(ExtractionSchema, _response.json()) if _response.status_code == 422: raise UnprocessableEntityError( - pydantic.parse_obj_as(HttpValidationError, _response.json()) - ) # type: ignore + pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) + ) try: _response_json = _response.json() except JSONDecodeError: raise ApiError(status_code=_response.status_code, body=_response.text) + raise ApiError(status_code=_response.status_code, body=_response_json) def infer_schema( diff --git a/pyproject.toml b/pyproject.toml index e1cbb0b..4aac292 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "llama-extract" -version = "0.0.4" +version = "0.0.5" description = "Infer schema and extract data from unstructured files" authors = ["Logan Markewich "] license = "MIT" @@ -13,9 +13,8 @@ packages = [{include = "llama_extract"}] [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -llama-index-core = ">=0.10.29" -llama-cloud = "^0.0.11" -pydantic = ">=1.10" +llama-index-core = "^0.11.0" +llama-cloud = ">=0.1.0" [tool.poetry.group.dev.dependencies] pytest = "^8.0.0" diff --git a/tests/test_extract.py b/tests/test_extract.py index 815baf0..be9fb60 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,10 +1,26 @@ import os import pytest +from llama_extract import LlamaExtract + + +TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/test.pdf") + @pytest.mark.skipif( os.environ.get("LLAMA_CLOUD_API_KEY", "") == "", reason="LLAMA_CLOUD_API_KEY not set", ) def test_simple() -> None: - pass + extractor = LlamaExtract( + api_key=os.environ["LLAMA_CLOUD_API_KEY"], + ) + + # Infer schema + schema = extractor.infer_schema( + "my_schema", [TEST_FILE] + ) + + # Extract data + results = extractor.extract(schema.id, [TEST_FILE]) +