Trying ingested images
Ansh5461 committed Apr 25, 2024
1 parent 2d58da9 commit 073666b
Showing 8 changed files with 474 additions and 20 deletions.
19 changes: 18 additions & 1 deletion querent/core/base_engine.py
@@ -4,6 +4,7 @@
from querent.callback.event_callback_interface import EventCallbackInterface
from querent.common.types.ingested_images import IngestedImages
from querent.common.types.ingested_messages import IngestedMessages
from querent.common.types.ingested_table import IngestedTables
from querent.common.types.ingested_tokens import IngestedTokens
from querent.common.types.ingested_code import IngestedCode
from querent.common.types.querent_event import EventState, EventType
@@ -120,6 +121,18 @@ async def process_code(self, data: IngestedCode):
        """
        raise NotImplementedError

    @abstractmethod
    async def process_tables(self, data: IngestedTables):
        """
        Process tables asynchronously.

        Args:
            data (IngestedTables): The input data to process.

        Returns:
            EventState: The state of the event is set with the event type and the timestamp
            of the event and set using `self.set_state(event_state)`.
        """
        pass

    @abstractmethod
    async def process_images(self, data: IngestedImages):
        """
@@ -229,9 +242,13 @@ async def _inner_worker():
elif isinstance(data, IngestedTokens):
await self.process_tokens(data)
elif isinstance(data, IngestedImages):
print("Got an image from queue--------------------------------------------------------------------------\n\n", data.ocr_text)
await self.process_images(data)
elif isinstance(data, IngestedCode):
await self.process_code(data)
elif isinstance(data, IngestedTables):
continue
# await self.process_tables(data)
elif data is None:
none_counter += 1
if none_counter >= 2:
@@ -241,7 +258,7 @@

else:
raise Exception(
f"Invalid data type {type(data)} for {self.__class__.__name__}. Supported type: {IngestedTokens, IngestedMessages}"
f"Invalid data type {type(data)} for {self.__class__.__name__}. Supported type: {IngestedTokens, IngestedMessages, IngestedTables, IngestedImages}"
)
except Exception as e:
self.logger.error(
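
Taken together, the base-engine changes add a `process_tables` hook but leave it dormant: the dispatch loop skips `IngestedTables` items with `continue` and the call to `process_tables` is commented out. A sketch of a concrete engine wired up for the hook is below; the `BaseEngine` class name and the row-flattening logic are assumptions, while the method signature, the `IngestedTables` fields, and the `EventState`/`set_state` usage come straight from this diff.

    # Hypothetical subclass; other abstract hooks omitted for brevity.
    from querent.common.types.ingested_table import IngestedTables
    from querent.common.types.querent_event import EventState, EventType
    from querent.core.base_engine import BaseEngine  # class name assumed

    class TableAwareEngine(BaseEngine):
        async def process_tables(self, data: IngestedTables):
            # Flatten rows into one text blob so downstream token-based
            # extraction can reuse the existing pipeline.
            flattened = "\n".join(", ".join(row) for row in data.table)
            # IngestedTables carries no doc_source in this commit, hence None.
            state = EventState(EventType.Graph, 1.0, flattened, data.file, doc_source=None)
            await self.set_state(new_state=state)
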
6 changes: 5 additions & 1 deletion querent/core/transformers/bert_ner_opensourcellm.py
@@ -1,6 +1,7 @@
import json
from unidecode import unidecode
from transformers import AutoTokenizer
from querent.common.types.ingested_table import IngestedTables
from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor
from querent.common.types.ingested_images import IngestedImages
from querent.config.core.opensource_llm_config import Opensource_LLM_Config
@@ -112,8 +113,11 @@ def validate(self) -> bool:
    def process_messages(self, data: IngestedMessages):
        return super().process_messages(data)

    def process_images(self, data: IngestedImages):
    async def process_images(self, data: IngestedImages):
        return super().process_images(data)

    async def process_tables(self, data: IngestedTables):
        return super().process_tables(data)

    async def process_code(self, data: IngestedCode):
        return super().process_code(data)
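
The only substantive change in this file is `def process_images` becoming `async def`. That matters because the dispatch loop in `base_engine.py` now executes `await self.process_images(data)`, and awaiting the return value of a plain method that returns `None` fails at runtime. A minimal repro of that failure mode (illustrative only, not repo code):

    import asyncio

    class SyncEngine:
        def process_images(self, data):  # old, non-async signature
            return None

    async def main():
        # Raises TypeError: object NoneType can't be used in 'await' expression
        await SyncEngine().process_images(None)

    asyncio.run(main())
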
124 changes: 122 additions & 2 deletions querent/core/transformers/fixed_entities_set_opensourcellm.py
@@ -1,6 +1,7 @@
import json
from unidecode import unidecode
from transformers import AutoTokenizer
from querent.common.types.ingested_table import IngestedTables
from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor
from querent.common.types.ingested_images import IngestedImages
from querent.config.core.opensource_llm_config import Opensource_LLM_Config
@@ -94,8 +95,119 @@ def validate(self) -> bool:
    def process_messages(self, data: IngestedMessages):
        return super().process_messages(data)

    def process_images(self, data: IngestedImages):
        return super().process_images(data)

    def process_tables(self, data: IngestedTables):
        pass

    async def process_images(self, data: IngestedImages):
        doc_entity_pairs = []
        doc_entity_pairs_ocr = []
        entities_list = []
        final_entities_list = []
        number_sentences = 0
        try:
            doc_source = data.doc_source
            if not Fixed_Entities_LLM.validate_ingested_images(data):
                self.set_termination_event()
                return
            if data.ocr_text:
                ocr_text = ' '.join(data.ocr_text)
            else:
                ocr_text = data.ocr_text

            if data.text:
                clean_text = ' '.join(data.text)
            else:
                clean_text = data.text

            file, content = data.file, clean_text

            ocr_content = ocr_text

            if ocr_content:
                if self.fixed_entities:
                    ocr_content = self.entity_context_extractor.find_entity_sentences(ocr_content)
                ocr_tokens = self.ner_llm_instance._tokenize_and_chunk(ocr_content)
                for tokenized_sentence, original_sentence, sentence_idx in ocr_tokens:
                    (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in ocr_tokens], self.isConfinedSearch, self.fixed_entities, self.sample_entities)
                    print("Entities ---------", entities)
                    print("Entities pairs ---------------------------", entity_pairs)
                    if entity_pairs:
                        doc_entity_pairs_ocr.append(self.ner_llm_instance.transform_entity_pairs(entity_pairs))
                    else:
                        continue
                    number_sentences = number_sentences + 1

                print("Doc entity pairs --------", doc_entity_pairs)

            if len(doc_entity_pairs_ocr) == 0 and len(ocr_content) != 0:
                if content:
                    if self.fixed_entities:
                        content = self.entity_context_extractor.find_entity_sentences(content)
                    tokens = self.ner_llm_instance._tokenize_and_chunk(content)
                    doc_entity_pairs_ocr = self.ner_llm_instance.extract_entities_from_sentence_for_given_sentence(ocr_content, sentence_idx, [s[1] for s in tokens], self.isConfinedSearch, self.fixed_entities, self.sample_entities)
                    print("doc_entity_pairs_ocr-----------------------", doc_entity_pairs_ocr)
                    for tokenized_sentence, original_sentence, sentence_idx in tokens:
                        # return the list of entities from the document, and the entity pairs
                        print("Here inside for loop")
                        (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_chunk(original_sentence, sentence_idx, [s[1] for s in tokens], self.isConfinedSearch, self.fixed_entities, self.sample_entities)
                        print("Entity pairs found from content", entity_pairs)
                        print("Entities found from content", entities)
                        if entity_pairs:
                            doc_entity_pairs.append(self.ner_llm_instance.transform_entity_pairs(entity_pairs))
                            entities_list.append(entities)
                        number_sentences = number_sentences + 1
                    # process those entities and the OCR entities here:
                    # with fixed entities, keep the most frequently occurring one;
                    # otherwise pair the OCR entity with whichever other entity
                    # occurs most often or has the highest confidence
                    final_entities_list = self.ner_llm_instance.create_subject_object_sentence_tuples(doc_entity_pairs_ocr, entities_list)

            elif len(ocr_content) == 0:
                # highest-confidence entity pair from the page text
                sample_entity_pair = [{'entity': 'Image', 'label': 'image_data', 'score': 1.0, 'start_idx': 1, 'noun_chunk': 'image', 'noun_chunk_length': 1}]
                final_entities_list = self.ner_llm_instance.create_subject_object_sentence_tuples(sample_entity_pair, entities_list)

            print("Final entities ------", final_entities_list)

            if self.sample_entities:
                doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=final_entities_list)
            if doc_entity_pairs and any(doc_entity_pairs):
                doc_entity_pairs = self.ner_llm_instance.remove_duplicates(final_entities_list)
                filtered_triples = process_data(doc_entity_pairs, file)
                if not filtered_triples:
                    self.logger.debug("No entity pairs")
                    return
                elif not self.skip_inferences:
                    relationships = self.semantic_extractor.process_tokens(filtered_triples)
                    self.logger.debug(f"length of relationships {len(relationships)}")
                    relationships = self.semantictriplefilter.filter_triples(relationships)
                    if len(relationships) > 0:
                        embedding_triples = self.create_emb.generate_embeddings(relationships)
                        if self.sample_relationships:
                            embedding_triples = self.predicate_context_extractor.process_predicate_types(embedding_triples)
                        for triple in embedding_triples:
                            if not self.termination_event.is_set():
                                graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(triple))
                                if graph_json:
                                    current_state = EventState(EventType.Graph, 1.0, graph_json, file, doc_source=doc_source)
                                    await self.set_state(new_state=current_state)
                                vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(triple))
                                if vector_json:
                                    current_state = EventState(EventType.Vector, 1.0, vector_json, file, doc_source=doc_source)
                                    await self.set_state(new_state=current_state)
                            else:
                                return
                    else:
                        return
                else:
                    return filtered_triples, file
            else:
                return
        except Exception as e:
            self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}")

    async def process_code(self, data: IngestedCode):
        return super().process_code(data)
@@ -107,6 +219,14 @@ def validate_ingested_tokens(data: IngestedTokens) -> bool:
            return False

        return True

    @staticmethod
    def validate_ingested_images(data: IngestedImages) -> bool:
        if data.is_error():
            return False

        return True

    def count_entity_pairs(self, doc_entity_pairs):
        total_pairs = 0
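
Stripped of the NER plumbing, the new `process_images` takes one of three paths: (1) the OCR text alone yields entity pairs; (2) the OCR text yields entities but no pairs, so OCR entities are paired against entities mined from the surrounding page text — note this branch passes `sentence_idx` to `extract_entities_from_sentence_for_given_sentence` before the loop that assigns it, so the value is stale from the earlier OCR loop or undefined if that loop never ran; (3) there is no OCR text at all, so a fixed `Image` placeholder entity is paired with the page-text entities. A minimal sketch of that control flow, where `extract_entities` and `extract_pairs` are hypothetical stand-ins for the `ner_llm_instance` helpers:

    def pair_image_with_entities(ocr_text, page_text, extract_entities, extract_pairs):
        """Mirror of the three branches above; the helper callables are stand-ins."""
        if ocr_text:
            pairs = extract_pairs(ocr_text)
            if pairs:
                return pairs  # branch 1: OCR text alone is enough
            # branch 2: pair OCR entities with entities from the page text
            return [(o, p) for o in extract_entities(ocr_text)
                    for p in extract_entities(page_text)]
        # branch 3: no OCR text; use a fixed placeholder subject
        placeholder = {'entity': 'Image', 'label': 'image_data', 'score': 1.0}
        return [(placeholder, p) for p in extract_entities(page_text)]
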
@@ -1,6 +1,7 @@
import asyncio
import json
import re
from querent.common.types.ingested_table import IngestedTables
from querent.core.transformers.fixed_entities_set_opensourcellm import Fixed_Entities_LLM
from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor
from querent.config.core.gpt_llm_config import GPTConfig
@@ -96,11 +97,29 @@ def validate(self) -> bool:
    def process_messages(self, data: IngestedMessages):
        return super().process_messages(data)

    def process_images(self, data: IngestedImages):
        return super().process_messages(data)

    async def process_images(self, data: IngestedImages):
        try:
            if not GPTLLM.validate_ingested_images(data):
                self.set_termination_event()
                return

            doc_source = data.doc_source
            relationships = []
            unique_keys = set()
            result = await self.llm_instance.process_images(data)
            if not result:
                return

            return None

        except Exception as e:
            self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}")

    async def process_code(self, data: IngestedCode):
        return super().process_messages(data)

    async def process_tables(self, data: IngestedTables):
        return super().process_tables(data)

    @staticmethod
    def validate_ingested_tokens(data: IngestedTokens) -> bool:
@@ -109,6 +128,14 @@ def validate_ingested_tokens(data: IngestedTokens) -> bool:
            return False

        return True

    @staticmethod
    def validate_ingested_images(data: IngestedImages) -> bool:
        if data.is_error():
            return False

        return True

    def extract_semantic_triples(self, chat_completion):
        # Extract the message content from the ChatCompletion
        message_content = chat_completion.choices[0].message.content.replace('\n', '')
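
The GPT-side `process_images` is currently a stub: it validates, awaits `self.llm_instance.process_images(data)`, then returns `None`, and the `relationships`/`unique_keys` locals are never used. The validate-then-terminate pattern both engines share is worth isolating; a sketch assuming the same `is_error()`/`set_termination_event()` interface shown in the diff:

    async def process_safely(self, data):
        # Shared pattern: an errored payload stops the pipeline via the
        # termination event rather than raising out of the worker loop.
        if data.is_error():
            self.set_termination_event()
            return None
        try:
            return await self.llm_instance.process_images(data)
        except Exception as e:
            # Engines log and swallow, so one bad document cannot kill the queue.
            self.logger.debug(f"Unable to process image payload: {e}")
            return None
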
16 changes: 8 additions & 8 deletions querent/ingestors/doc/doc_ingestor.py
@@ -99,14 +99,14 @@ async def extract_text_from_doc(self, collected_bytes: CollectedBytes, doc_source
)

        i = 1
        for table in doc.tables:
            table_data = []
            for row in table.rows:
                row_data = []
                for cell in row.cells:
                    row_data.append(cell.text.strip())
                table_data.append(row_data)
            yield IngestedTables(file=collected_bytes.file, table = table_data, page_num = i, text = text, error=None)
        # for table in doc.tables:
        #     table_data = []
        #     for row in table.rows:
        #         row_data = []
        #         for cell in row.cells:
        #             row_data.append(cell.text.strip())
        #         table_data.append(row_data)
        #     yield IngestedTables(file=collected_bytes.file, table = table_data, page_num = i, text = text, error=None)

        i = 1
        for rel in doc.part.rels.values():
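
The table path is switched off here wholesale rather than removed. Two things stand out in the disabled block: it relies on the `doc.tables` / `rows` / `cells` traversal (which implies python-docx), and the counter `i` is initialised to 1 but never incremented, so every table would report `page_num=1` — DOCX has no real page numbers at parse time anyway. A standalone sketch of the same traversal with a table index in place of the page number (the `IngestedTables` wiring is omitted):

    from docx import Document  # python-docx, implied by doc.tables/rows/cells

    def extract_tables(path):
        doc = Document(path)
        for table_idx, table in enumerate(doc.tables, start=1):
            # Strip each cell so ragged whitespace does not leak downstream.
            table_data = [[cell.text.strip() for cell in row.cells]
                          for row in table.rows]
            yield table_idx, table_data
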
10 changes: 6 additions & 4 deletions querent/ingestors/pdfs/pdf_ingestor_v1.py
@@ -14,6 +14,7 @@
import fitz
from PIL import Image
import io
import base64

import pybase64
import pytesseract
@@ -111,8 +112,8 @@ async def extract_and_process_pdf(
doc_source=doc_source,
)

async for table in self.extract_table(collected_bytes):
yield table
# async for table in self.extract_table(collected_bytes):
# yield table

        async for image_data in self.extract_img(loader, collected_bytes.file, collected_bytes.data):
            yield image_data
@@ -170,8 +171,8 @@ async def extract_img(self, doc, file_path, data):

yield IngestedImages(
file=file_path,
image=pybase64.b64encode(data),
image_name=f"{str(uuid.UUID)}.{image_ext}",
image=base64.b64encode(image_data).decode('utf-8'),
image_name=f"{str(uuid.uuid4())}.{image_ext}",
page_num=page_num,
text=[text_content],
coordinates=None,
@@ -183,6 +184,7 @@ async def get_ocr_from_image(self, image):
try:
image = Image.open(io.BytesIO(image))
text = pytesseract.image_to_string(image)
print("Got text from images ---------------------------")
except Exception as e:
            self.logger.error(f"Exception-{e}")
raise e
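
The two one-line fixes in this file are easy to miss: the old code base64-encoded `data` (the entire PDF buffer) instead of the extracted image bytes, and `str(uuid.UUID)` is the repr of the class itself rather than an identifier, so every image got the same name; `uuid.uuid4()` generates a fresh one. A standalone sketch of the corrected extract-encode-OCR loop, using the PyMuPDF and pytesseract calls the ingestor already imports (the `IngestedImages` wiring is omitted, and the dict layout here is illustrative):

    import base64
    import io
    import uuid

    import fitz  # PyMuPDF
    import pytesseract
    from PIL import Image

    def extract_images(pdf_bytes):
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        for page_num in range(doc.page_count):
            for img in doc.get_page_images(page_num):
                xref = img[0]
                info = doc.extract_image(xref)  # {'image': bytes, 'ext': str, ...}
                image_data, image_ext = info["image"], info["ext"]
                ocr = pytesseract.image_to_string(Image.open(io.BytesIO(image_data)))
                yield {
                    # encode the image bytes, not the whole PDF buffer
                    "image": base64.b64encode(image_data).decode("utf-8"),
                    # uuid.uuid4() yields a fresh id; str(uuid.UUID) was the class repr
                    "image_name": f"{uuid.uuid4()}.{image_ext}",
                    "page_num": page_num,
                    "ocr_text": ocr,
                }
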
