diff --git a/querent/config/core/llm_config.py b/querent/config/core/llm_config.py index 1ae20efc..744475e6 100644 --- a/querent/config/core/llm_config.py +++ b/querent/config/core/llm_config.py @@ -18,11 +18,7 @@ class LLM_Config(EngineConfig): rel_model_path: str = './tests/llama-2-7b-chat.Q5_K_M.gguf' grammar_file_path: str = './querent/kg/rel_helperfunctions/json.gbnf' emb_model_name: str = 'sentence-transformers/all-MiniLM-L6-v2' - user_context: str = Field(default="""Please analyze the provided context and two entities. Use this information to answer the users query below. -Context: {context} -Entity 1: {entity1} and Entity 2: {entity2} -Query: In a semantic triple (Subject, Predicate & Object) framework, determine which of the above entity is the subject and which is the object based on the context along with the predicate between these entities. Please also identify the subject type, object type & predicate type. -Answer:""") + user_context: str = Field(default="In a semantic triple (Subject, Predicate & Object) framework, determine which of the above entities is the subject and which is the object based on the context, along with the predicate between these entities. Please also identify the subject type, object type & predicate type.") enable_filtering: bool = False filter_params: dict = Field(default_factory=lambda: { 'score_threshold': 0.6, diff --git a/querent/core/transformers/bert_ner_opensourcellm.py b/querent/core/transformers/bert_ner_opensourcellm.py index e16a686c..e72d48e2 100644 --- a/querent/core/transformers/bert_ner_opensourcellm.py +++ b/querent/core/transformers/bert_ner_opensourcellm.py @@ -151,8 +151,6 @@ async def process_tokens(self, data: IngestedTokens): if content: if self.fixed_entities: content = self.entity_context_extractor.find_entity_sentences(content) - # if self.fixed_relationships: - # content = self.predicate_context_extractor.find_predicate_sentences(content) tokens = self.ner_llm_instance._tokenize_and_chunk(content) for tokenized_sentence, original_sentence, sentence_idx in tokens: (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in tokens],self.isConfinedSearch, self.fixed_entities, self.sample_entities) @@ -164,7 +162,6 @@ async def process_tokens(self, data: IngestedTokens): if self.sample_entities: doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs) if any(doc_entity_pairs): - print("Found doc_entity_pairs-------------------------------------", len(doc_entity_pairs)) doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs) pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs) if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor: @@ -173,7 +170,6 @@ async def process_tokens(self, data: IngestedTokens): else: pairs_withemb = pairs_withattn pairs_with_predicates = process_data(pairs_withemb, file) - print("Found doc_entity_pairs-------------------------------------", len(pairs_with_predicates)) if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor: cluster_output = self.triple_filter.cluster_triples(pairs_with_predicates) clustered_triples = cluster_output['filtered_triples'] @@ -186,20 +182,16 @@ async def process_tokens(self, data: IngestedTokens): 
filtered_triples = pairs_with_predicates else: filtered_triples = pairs_with_predicates - print("Found doc_entity_pairs-------------------------------------", len(filtered_triples)) if not filtered_triples: self.logger.debug("No entity pairs") return elif not self.skip_inferences: - print("Extracting Entities-------------------------------------", filtered_triples) - relationships = self.semantic_extractor.process_tokens(filtered_triples[:5]) + relationships = self.semantic_extractor.process_tokens(filtered_triples, fixed_entities=(len(self.sample_entities) >= 1)) relationships = self.semantictriplefilter.filter_triples(relationships) - print("Relationships: {}".format(relationships)) if len(relationships) > 0: if self.fixed_relationships and self.sample_relationships: embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) elif self.sample_relationships: - print("Only for sample_relationships") embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) else: embedding_triples = self.create_emb.generate_embeddings(relationships) @@ -222,5 +214,4 @@ async def process_tokens(self, data: IngestedTokens): else: return filtered_triples, file except Exception as e: - print("Exception Caught: %s" % e) self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}") diff --git a/querent/core/transformers/fixed_entities_set_opensourcellm.py b/querent/core/transformers/fixed_entities_set_opensourcellm.py index 40fe8aa5..57395163 100644 --- a/querent/core/transformers/fixed_entities_set_opensourcellm.py +++ b/querent/core/transformers/fixed_entities_set_opensourcellm.py @@ -149,8 +149,6 @@ async def process_tokens(self, data: IngestedTokens): if content: if self.fixed_entities: content = self.entity_context_extractor.find_entity_sentences(content) - # if self.fixed_relationships: - # content = self.predicate_context_extractor.find_predicate_sentences(content) tokens = self.ner_llm_instance._tokenize_and_chunk(content) for tokenized_sentence, original_sentence, sentence_idx in tokens: (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in tokens],self.isConfinedSearch, self.fixed_entities, self.sample_entities) @@ -176,7 +174,6 @@ async def process_tokens(self, data: IngestedTokens): if self.fixed_relationships and self.sample_relationships: embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) elif self.sample_relationships: - print("Only for sample_relationships") embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) else: embedding_triples = self.create_emb.generate_embeddings(relationships) diff --git a/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py b/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py index 17a2f8ff..87b8a5a5 100644 --- a/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py +++ b/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py @@ -58,7 +58,9 @@ def __init__( is_confined_search = config.is_confined_search, huggingface_token = config.huggingface_token, spacy_model_path = config.spacy_model_path, - nltk_path = config.nltk_path) + nltk_path = config.nltk_path, + fixed_relationships = config.fixed_relationships, + 
sample_relationships = config.sample_relationships) self.fixed_entities = config.fixed_entities self.is_confined_search = config.is_confined_search self.fixed_relationships = config.fixed_relationships @@ -144,7 +146,7 @@ def remove_items_from_tuples(data: List[Tuple[str, str, str]]) -> List[Tuple[str async def process_triples(self, context, entity1, entity2, entity1_label, entity2_label): try: - if not self.user_context: + if not self.user_context and not self.fixed_entities: identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. Context: {context} @@ -163,7 +165,29 @@ async def process_triples(self, context, entity1, entity2, entity1_label, entity """ messages_classify_entity = [ {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": "Query: First, identify all geological entities in the provided context. Then, create relevant semantic triples (Subject, Predicate, Object) and also categorize the respective the Subject, Object types (e.g. location, person, event, material, process etc.) and Predicate type. Use the above output format to provide all the relevant semantic triples."}, + {"role": "user", "content": "Query: In the context of a semantic triple framework, first identify which entity is the subject and which is the object, along with their respective types. Also determine the predicate and predicate type."}, + ] + elif not self.user_context and self.fixed_entities : + identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. + + Context: {context} + Entity 1: {entity1} and Entity 2: {entity2} + Entity 1_Type: {entity1_label} and Entity 2_Type: {entity2_label} + Output Format: + [ + {{ + 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', + 'predicate': 'The relationship (predicate) between the subject and the object.', + 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', + 'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.', + 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', + 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurrence etc.' + }}, + ] + """ + messages_classify_entity = [ + {"role": "user", "content": identify_entity_message}, + {"role": "user", "content": "Query: In the context of a semantic triple framework, first identify which entity is the subject and which is the object, and also validate and output their respective types. Also determine the predicate and predicate type."}, ] elif self.user_context and self.fixed_entities : identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. 
@@ -253,14 +277,20 @@ def generate_output_tuple(self,result, context_json): ) return output_tuple - + + def extract_key(tup): + subject, json_string, obj = tup + data = json.loads(json_string.replace("\n", "")) + return (subject, data.get('predicate'), obj) + async def process_tokens(self, data: IngestedTokens): try: if not GPTLLM.validate_ingested_tokens(data): self.set_termination_event() return relationships = [] - result = await self.llm_instance.process_tokens(data) + unique_keys = set() + result = await self.llm_instance.process_tokens(data) if not result: return else: filtered_triples, file = result @@ -275,12 +305,14 @@ async def process_tokens(self, data: IngestedTokens): result = await self.process_triples(context, entity1_nn_chunk, entity2_nn_chunk, entity1_label, entity2_label) if result: output_tuple = self.generate_output_tuple(result, context_json) - relationships.append(output_tuple) + key = GPTLLM.extract_key(output_tuple) + if key not in unique_keys: + unique_keys.add(key) + relationships.append(output_tuple) if len(relationships) > 0: if self.fixed_relationships and self.sample_relationships: embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) elif self.sample_relationships: - print("Only for sample_relationships") embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) else: embedding_triples = self.create_emb.generate_embeddings(relationships) diff --git a/querent/core/transformers/relationship_extraction_llm.py b/querent/core/transformers/relationship_extraction_llm.py index 90506edb..06515c96 100644 --- a/querent/core/transformers/relationship_extraction_llm.py +++ b/querent/core/transformers/relationship_extraction_llm.py @@ -103,13 +103,13 @@ def validate(self, data) -> bool: self.logger.error(f"Error in validation: {e}") return False - def process_tokens(self, payload): + def process_tokens(self, payload, fixed_entities = False): try: triples = payload trimmed_triples = self.normalizetriples_buildindex(triples) if self.rag_approach == True: self.rag_retriever.build_faiss_index(trimmed_triples) - relationships = self.extract_relationships(triples) + relationships = self.extract_relationships(triples, fixed_entities) return relationships @@ -149,33 +149,6 @@ def create_semantic_triple(self, input1, input2): }), input1.get("object","") ) - # else: - # if input1.get("subject","").lower() in input2_data.get("entity1_nn_chunk","").lower or input2_data.get("entity1_nn_chunk","").lower in input1.get("subject","").lower() or input2_data.get("entity1_nn_chunk","").lower == input1.get("subject","").lower(): - # triple = ( - # input1.get("subject",""), - # json.dumps({ - # "predicate": input1.get("predicate",""), - # "predicate_type": input1.get("predicate_type","Unlabeled"), - # "context": input2_data.get("context", ""), - # "file_path": input2_data.get("file_path", ""), - # "subject_type": input2_data.get("entity1_label","Unlabeled"), - # "object_type": input2_data.get("entity2_label","Unlabeled") - # }), - # input1.get("object","") - # ) - # else: - # triple = ( - # input1.get("subject",""), - # json.dumps({ - # "predicate": input1.get("predicate",""), - # "predicate_type": input1.get("predicate_type","Unlabeled"), - # "context": input2_data.get("context", ""), - # "file_path": input2_data.get("file_path", ""), - # "subject_type": input2.get("entity2_label","Unlabeled"), - # "object_type": input1.get("entity1_label","Unlabeled") - # }), 
- # input1.get("object","") - # ) return triple except Exception as e: self.logger.error(f"Error in creating semantic triple: {e}") @@ -195,7 +168,7 @@ def replace_entities(self, text, entity1, entity2): return data - def extract_relationships(self, triples): + def extract_relationships(self, triples, fixed_entities = False): try: self.logger.debug(f"Length of identified triples {len(triples)}") updated_triples = [] @@ -210,17 +183,34 @@ def extract_relationships(self, triples): top_docs = self.rag_retriever.retrieve_documents(db, prompt=prompt) documents = top_docs else: - query = """Please analyze the provided context and two entities. Use this information to answer the users query below. + if fixed_entities == False: + query = """Please analyze the provided context and two entities. Use this information to answer the users query below. Context: {context} Entity 1: {entity1} and Entity 2: {entity2} Query:{question} -Answer:""" - if not self.config.qa_template: - question = "In the context of reservoir studies, identify the subject, predicate, and object in a semantic triple framework, focusing on reservoir attributes (e.g., porosity, permeability), processes (e.g., influences, determines), and outcomes (e.g., recovery efficiency). Specify the types for the subject (attribute), predicate (process), and object (outcome)." - query = query.format(question = question, context = context, entity1=predicate.get('entity1_nn_chunk', ''), entity2=predicate.get('entity2_nn_chunk', '')) +Answer:""" + if not self.config.qa_template: + question = "In the context of a semantic triple framework, first identify which entity is subject and which is the object along with their respective types. Also determine the predicate and predicate type." + else: + question = self.config.qa_template + query = query.format(question = question, context = context, entity1=predicate.get('entity1_nn_chunk', ''), entity2=predicate.get('entity2_nn_chunk', '')) else: - question = self.config.qa_template - query = query.format(question = question, context = context, entity1=predicate.get('entity1_nn_chunk', ''), entity2=predicate.get('entity2_nn_chunk', '')) + query = """Please analyze the provided context and two entities along with their identified labels. Use this information to answer the users query below. +Context: {context} +Entity 1: {entity1} and Entity 1_label: {entity1_label} +Entity 2: {entity2} and Entity 2_label: {entity2_label} +Query:{question} +Answer:""" + if not self.config.qa_template: + question = "In the context of a semantic triple framework, first identify which entity is subject and which is the object, validate and output their respective types. Also determine the predicate and predicate type." 
+ else: + question = self.config.qa_template + query = query.format(question = question, + context = context, + entity1=predicate.get('entity1_nn_chunk', ''), + entity2=predicate.get('entity2_nn_chunk', ''), + entity1_label=predicate.get('entity1_label', ''), + entity2_label=predicate.get('entity2_label', '')) answer_relation = self.qa_system.ask_question(prompt=query, llm=self.qa_system.llm, grammar=self.grammar) try: choices_text = answer_relation['choices'][0]['text'] @@ -242,6 +232,8 @@ def trim_triples(self, data): 'context': predicate_dict.get('context', ''), 'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''), 'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''), + 'entity1_label': predicate_dict.get('entity1_label', ''), + 'entity2_label': predicate_dict.get('entity2_label', ''), 'file_path': predicate_dict.get('file_path', '') } trimmed_data.append((entity1, trimmed_predicate, entity2)) diff --git a/querent/kg/ner_helperfunctions/fixed_predicate.py b/querent/kg/ner_helperfunctions/fixed_predicate.py index 2cc62f9d..2ad05d12 100644 --- a/querent/kg/ner_helperfunctions/fixed_predicate.py +++ b/querent/kg/ner_helperfunctions/fixed_predicate.py @@ -157,12 +157,9 @@ def construct_predicate_json(self, relationships=None, relationship_types=None): def update_embedding_triples_with_similarity(self, predicate_json_emb, embedding_triples): try: - print("Updating embedding------------------------------") predicate_json_emb = [json.loads(item) for item in predicate_json_emb] predicate_emb_list = [item["predicate_emb"] for item in predicate_json_emb if item["predicate_emb"] != "Not Implemented"] - print("Updating embedding------------------------------1") predicate_emb_matrix = np.array(predicate_emb_list) - print("Updating embedding------------------------------2") updated_embedding_triples = [] for triple in embedding_triples: entity, triple_json, study_field = triple @@ -176,8 +173,7 @@ def update_embedding_triples_with_similarity(self, predicate_json_emb, embedding similarities = cosine_similarity(current_predicate_emb, predicate_emb_matrix) max_similarity_index = np.argmax(similarities) most_similar_predicate_details = predicate_json_emb[max_similarity_index] - print("Score: ", similarities[0][max_similarity_index]) - if similarities[0][max_similarity_index] > 0.4: + if similarities[0][max_similarity_index] > 0.5: triple_data["predicate_type"] = most_similar_predicate_details["type"] if most_similar_predicate_details["relationship"].lower() != "unlabelled": triple_data["predicate"] = most_similar_predicate_details["relationship"] diff --git a/querent/kg/rel_helperfunctions/embedding_store.py b/querent/kg/rel_helperfunctions/embedding_store.py index fedba29e..c340c7dc 100644 --- a/querent/kg/rel_helperfunctions/embedding_store.py +++ b/querent/kg/rel_helperfunctions/embedding_store.py @@ -144,9 +144,9 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed try: triples = payload processed_pairs = [] - for entity, json_string, related_entity in triples: try: + json_string = json_string.replace("\n", "") data = json.loads(json_string) context = data.get("context", "").replace('"', '\\"') predicate = data.get("predicate","").replace('"', '\\"') @@ -159,7 +159,6 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed if relationship_finder and generate_embeddings_with_fixed_relationship: predicate_embedding = self.get_embeddings([predicate + " ("+predicate_type+")"])[0] elif relationship_finder: - print("Predicate 
----------------------------", self.get_embeddings([predicate_type])) predicate_embedding = self.get_embeddings([predicate_type])[0] essential_data = { "context": context, diff --git a/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py b/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py index cb9a24c7..a67ca384 100644 --- a/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py +++ b/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py @@ -1,74 +1,76 @@ -import asyncio -from asyncio import Queue -import json -from pathlib import Path -from querent.callback.event_callback_interface import EventCallbackInterface -from querent.collectors.fs.fs_collector import FSCollectorFactory -from querent.common.types.ingested_tokens import IngestedTokens -from querent.common.types.querent_event import EventState, EventType -from querent.config.collector.collector_config import FSCollectorConfig -from querent.common.uri import Uri -from querent.config.core.llm_config import LLM_Config -from querent.ingestors.ingestor_manager import IngestorFactoryManager -import pytest -import uuid -from querent.common.types.file_buffer import FileBuffer -from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -from querent.querent.resource_manager import ResourceManager -from querent.querent.querent import Querent -import time +# import asyncio +# from asyncio import Queue +# import json +# from pathlib import Path +# from querent.callback.event_callback_interface import EventCallbackInterface +# from querent.collectors.fs.fs_collector import FSCollectorFactory +# from querent.common.types.ingested_tokens import IngestedTokens +# from querent.common.types.querent_event import EventState, EventType +# from querent.config.collector.collector_config import FSCollectorConfig +# from querent.common.uri import Uri +# from querent.config.core.llm_config import LLM_Config +# from querent.ingestors.ingestor_manager import IngestorFactoryManager +# import pytest +# import uuid +# from querent.common.types.file_buffer import FileBuffer +# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM +# from querent.querent.resource_manager import ResourceManager +# from querent.querent.querent import Querent +# import time -@pytest.mark.asyncio -async def test_ingest_all_async(): - # Set up the collectors - directories = [ "./tests/data/llm/predicate_checker"] - collectors = [ - FSCollectorFactory().resolve( - Uri("file://" + str(Path(directory).resolve())), - FSCollectorConfig(config_source={ - "id": str(uuid.uuid4()), - "root_path": directory, - "name": "Local-config", - "config": {}, - "backend": "localfile", - "uri": "file://", - }), - ) - for directory in directories - ] +# @pytest.mark.asyncio +# async def test_ingest_all_async(): +# # Set up the collectors +# directories = [ "./tests/data/llm/predicate_checker"] +# collectors = [ +# FSCollectorFactory().resolve( +# Uri("file://" + str(Path(directory).resolve())), +# FSCollectorConfig(config_source={ +# "id": str(uuid.uuid4()), +# "root_path": directory, +# "name": "Local-config", +# "config": {}, +# "backend": "localfile", +# "uri": "file://", +# }), +# ) +# for directory in directories +# ] - # Set up the result queue - result_queue = asyncio.Queue() +# # Set up the result queue +# result_queue = asyncio.Queue() - # Create the IngestorFactoryManager - ingestor_factory_manager = IngestorFactoryManager( - collectors=collectors, result_queue=result_queue - ) - ingest_task = 
asyncio.create_task(ingestor_factory_manager.ingest_all_async()) - resource_manager = ResourceManager() - bert_llm_config = LLM_Config( - # ner_model_name="botryan96/GeoBERT", - enable_filtering=True, - filter_params={ - 'score_threshold': 0.5, - 'attention_score_threshold': 0.1, - 'similarity_threshold': 0.5, - 'min_cluster_size': 5, - 'min_samples': 3, - 'cluster_persistence_threshold':0.2 - } -# ,fixed_relationships=[ -# "Increase in research funding leads to environmental science focus", -# "Dr. Emily Stanton's advocacy for cleaner energy", -# "University's commitment to reduce carbon emissions", -# "Dr. Stanton's research influences architectural plans", -# "Collaborative project between sociology and environmental sciences", -# "Student government launches mental health awareness workshops", -# "Enhanced fitness programs improve sports teams' performance", -# "Coach Torres influences student-athletes' holistic health", -# "Partnership expands access to digital resources", -# "Interdisciplinary approach enriches academic experience" -# ] +# # Create the IngestorFactoryManager +# ingestor_factory_manager = IngestorFactoryManager( +# collectors=collectors, result_queue=result_queue +# ) +# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) +# resource_manager = ResourceManager() +# bert_llm_config = LLM_Config( +# # ner_model_name="botryan96/GeoBERT", +# enable_filtering=True, +# filter_params={ +# 'score_threshold': 0.5, +# 'attention_score_threshold': 0.1, +# 'similarity_threshold': 0.5, +# 'min_cluster_size': 5, +# 'min_samples': 3, +# 'cluster_persistence_threshold':0.2 +# } +# # ,fixed_entities = ["university", "greenwood", "liam zheng", "department", "Metroville", "Emily Stanton", "Coach", "health", "training", "atheletes" ] +# # ,sample_entities=["organization", "organization", "person", "department", "city", "person", "person", "method", "method", "person"] +# # ,fixed_relationships=[ +# # "Increase in research funding leads to environmental science focus", +# # "Dr. Emily Stanton's advocacy for cleaner energy", +# # "University's commitment to reduce carbon emissions", +# # "Dr. Stanton's research influences architectural plans", +# # "Collaborative project between sociology and environmental sciences", +# # "Student government launches mental health awareness workshops", +# # "Enhanced fitness programs improve sports teams' performance", +# # "Coach Torres influences student-athletes' holistic health", +# # "Partnership expands access to digital resources", +# # "Interdisciplinary approach enriches academic experience" +# # ] # , sample_relationships=[ # "Causal", # "Contributory", @@ -80,29 +82,31 @@ async def test_ingest_all_async(): # "Influential", # "Collaborative", # "Enriching" -# ] - # user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. The above context is from a geological research study on reservoirs and the above entities and their respective types have already been identified. Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed. Also provide the predicate type. 
Answer:" - ) - llm_instance = BERTLLM(result_queue, bert_llm_config) - class StateChangeCallback(EventCallbackInterface): - def handle_event(self, event_type: EventType, event_state: EventState): - if event_state['event_type'] == EventType.Graph: - triple = json.loads(event_state['payload']) - print("triple: {}".format(triple)) - assert isinstance(triple['subject'], str) and triple['subject'] - elif event_state['event_type'] == EventType.Vector: - triple = json.loads(event_state['payload']) - # print("triple: {}".format(triple)) - llm_instance.subscribe(EventType.Graph, StateChangeCallback()) - llm_instance.subscribe(EventType.Vector, StateChangeCallback()) - querent = Querent( - [llm_instance], - resource_manager=resource_manager, - ) - querent_task = asyncio.create_task(querent.start()) - await asyncio.gather(ingest_task, querent_task) +# ], +# # is_confined_search = True, + +# # user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. The above context is from a geological research study on reservoirs and the above entities and their respective types have already been identified. Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed. Also provide the predicate type. Answer:" +# ) +# llm_instance = BERTLLM(result_queue, bert_llm_config) +# class StateChangeCallback(EventCallbackInterface): +# def handle_event(self, event_type: EventType, event_state: EventState): +# if event_state['event_type'] == EventType.Graph: +# triple = json.loads(event_state['payload']) +# print("triple: {}".format(triple)) +# assert isinstance(triple['subject'], str) and triple['subject'] +# elif event_state['event_type'] == EventType.Vector: +# triple = json.loads(event_state['payload']) +# # print("triple: {}".format(triple)) +# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) +# llm_instance.subscribe(EventType.Vector, StateChangeCallback()) +# querent = Querent( +# [llm_instance], +# resource_manager=resource_manager, +# ) +# querent_task = asyncio.create_task(querent.start()) +# await asyncio.gather(ingest_task, querent_task) -if __name__ == "__main__": +# if __name__ == "__main__": - # Run the async function - asyncio.run(test_ingest_all_async()) +# # Run the async function +# asyncio.run(test_ingest_all_async()) diff --git a/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py b/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py new file mode 100644 index 00000000..2dcadacd --- /dev/null +++ b/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py @@ -0,0 +1,115 @@ +# import asyncio +# from asyncio import Queue +# import json +# from pathlib import Path +# from querent.callback.event_callback_interface import EventCallbackInterface +# from querent.collectors.fs.fs_collector import FSCollectorFactory +# from querent.common.types.ingested_tokens import IngestedTokens +# from querent.common.types.querent_event import EventState, EventType +# from querent.config.collector.collector_config import FSCollectorConfig +# from querent.common.uri import Uri +# from querent.config.core.llm_config import LLM_Config +# from querent.ingestors.ingestor_manager import IngestorFactoryManager +# import pytest +# import uuid +# from querent.common.types.file_buffer import 
FileBuffer +# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM +# from querent.querent.resource_manager import ResourceManager +# from querent.querent.querent import Querent +# import time +# from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM +# from querent.config.core.gpt_llm_config import GPTConfig + +# @pytest.mark.asyncio +# async def test_ingest_all_async(): +# # Set up the collectors +# directories = [ "./tests/data/llm/predicate_checker"] +# collectors = [ +# FSCollectorFactory().resolve( +# Uri("file://" + str(Path(directory).resolve())), +# FSCollectorConfig(config_source={ +# "id": str(uuid.uuid4()), +# "root_path": directory, +# "name": "Local-config", +# "config": {}, +# "backend": "localfile", +# "uri": "file://", +# }), +# ) +# for directory in directories +# ] + +# # Set up the result queue +# result_queue = asyncio.Queue() + +# # Create the IngestorFactoryManager +# ingestor_factory_manager = IngestorFactoryManager( +# collectors=collectors, result_queue=result_queue +# ) +# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) +# resource_manager = ResourceManager() +# bert_llm_config = GPTConfig( +# # ner_model_name="botryan96/GeoBERT", +# enable_filtering=True, +# openai_api_key="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", +# filter_params={ +# 'score_threshold': 0.5, +# 'attention_score_threshold': 0.1, +# 'similarity_threshold': 0.5, +# 'min_cluster_size': 5, +# 'min_samples': 3, +# 'cluster_persistence_threshold':0.2 +# } +# # ,fixed_entities = ["university", "greenwood", "liam zheng", "department", "Metroville", "Emily Stanton", "Coach", "health", "training", "atheletes" ] +# # ,sample_entities=["organization", "organization", "person", "department", "city", "person", "person", "method", "method", "person"] +# # ,fixed_relationships=[ +# # "Increase in research funding leads to environmental science focus", +# # "Dr. Emily Stanton's advocacy for cleaner energy", +# # "University's commitment to reduce carbon emissions", +# # "Dr. Stanton's research influences architectural plans", +# # "Collaborative project between sociology and environmental sciences", +# # "Student government launches mental health awareness workshops", +# # "Enhanced fitness programs improve sports teams' performance", +# # "Coach Torres influences student-athletes' holistic health", +# # "Partnership expands access to digital resources", +# # "Interdisciplinary approach enriches academic experience" +# # ] +# , sample_relationships=[ +# "Causal", +# "Contributory", +# "Causal", +# "Influential", +# "Collaborative", +# "Initiative", +# "Beneficial", +# "Influential", +# "Collaborative", +# "Enriching" +# ], +# # is_confined_search = True, + +# # user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. The above context is from a geological research study on reservoirs and the above entities and their respective types have already been identified. Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed. Also provide the predicate type. 
Answer:" +# ) +# llm_instance = GPTLLM(result_queue, bert_llm_config) +# class StateChangeCallback(EventCallbackInterface): +# def handle_event(self, event_type: EventType, event_state: EventState): +# if event_state['event_type'] == EventType.Graph: +# triple = json.loads(event_state['payload']) +# print("triple: {}".format(triple)) +# assert isinstance(triple['subject'], str) and triple['subject'] +# elif event_state['event_type'] == EventType.Vector: +# triple = json.loads(event_state['payload']) +# # print("triple: {}".format(triple)) +# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) +# llm_instance.subscribe(EventType.Vector, StateChangeCallback()) +# querent = Querent( +# [llm_instance], +# resource_manager=resource_manager, +# ) +# querent_task = asyncio.create_task(querent.start()) +# await asyncio.gather(ingest_task, querent_task) + +# if __name__ == "__main__": + +# # Run the async function +# asyncio.run(test_ingest_all_async())