Commit

working fixed and sample relationships

ngupta10 committed Apr 17, 2024
1 parent 0d2e910 commit 9f73b51
Showing 9 changed files with 283 additions and 161 deletions.
6 changes: 1 addition & 5 deletions querent/config/core/llm_config.py
@@ -18,11 +18,7 @@ class LLM_Config(EngineConfig):
     rel_model_path: str = './tests/llama-2-7b-chat.Q5_K_M.gguf'
     grammar_file_path: str = './querent/kg/rel_helperfunctions/json.gbnf'
     emb_model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'
-    user_context: str = Field(default="""Please analyze the provided context and two entities. Use this information to answer the users query below.
-Context: {context}
-Entity 1: {entity1} and Entity 2: {entity2}
-Query: In a semantic triple (Subject, Predicate & Object) framework, determine which of the above entity is the subject and which is the object based on the context along with the predicate between these entities. Please also identify the subject type, object type & predicate type.
-Answer:""")
+    user_context: str = Field(default="In a semantic triple (Subject, Predicate & Object) framework, determine which of the above entity is the subject and which is the object based on the context along with the predicate between these entities. Please also identify the subject type, object type & predicate type.")
     enable_filtering: bool = False
     filter_params: dict = Field(default_factory=lambda: {
         'score_threshold': 0.6,
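Note: the default user_context now carries only the query text; the surrounding scaffolding (context and entity slots) is assembled at inference time, as process_triples further below shows. A minimal sketch of that assembly, assuming a hypothetical build_prompt helper whose wording mirrors the removed template:

# Hypothetical helper (not part of the commit) showing how a query-only
# user_context can be wrapped with the context and entity pair at runtime.
def build_prompt(user_context: str, context: str, entity1: str, entity2: str) -> str:
    return (
        "Please analyze the provided context and two entities. "
        "Use this information to answer the query below.\n"
        f"Context: {context}\n"
        f"Entity 1: {entity1} and Entity 2: {entity2}\n"
        f"Query: {user_context}\n"
        "Answer:"
    )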
11 changes: 1 addition & 10 deletions querent/core/transformers/bert_ner_opensourcellm.py
@@ -151,8 +151,6 @@ async def process_tokens(self, data: IngestedTokens):
             if content:
                 if self.fixed_entities:
                     content = self.entity_context_extractor.find_entity_sentences(content)
-                # if self.fixed_relationships:
-                #     content = self.predicate_context_extractor.find_predicate_sentences(content)
                 tokens = self.ner_llm_instance._tokenize_and_chunk(content)
                 for tokenized_sentence, original_sentence, sentence_idx in tokens:
                     (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in tokens],self.isConfinedSearch, self.fixed_entities, self.sample_entities)
@@ -164,7 +162,6 @@ async def process_tokens(self, data: IngestedTokens):
             if self.sample_entities:
                 doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
             if any(doc_entity_pairs):
-                print("Found doc_entity_pairs-------------------------------------", len(doc_entity_pairs))
                 doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs)
                 pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs)
                 if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor:
@@ -173,7 +170,6 @@ async def process_tokens(self, data: IngestedTokens):
                 else:
                     pairs_withemb = pairs_withattn
                 pairs_with_predicates = process_data(pairs_withemb, file)
-                print("Found doc_entity_pairs-------------------------------------", len(pairs_with_predicates))
                 if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor:
                     cluster_output = self.triple_filter.cluster_triples(pairs_with_predicates)
                     clustered_triples = cluster_output['filtered_triples']
@@ -186,20 +182,16 @@ async def process_tokens(self, data: IngestedTokens):
                         filtered_triples = pairs_with_predicates
                 else:
                     filtered_triples = pairs_with_predicates
-                print("Found doc_entity_pairs-------------------------------------", len(filtered_triples))
                 if not filtered_triples:
                     self.logger.debug("No entity pairs")
                     return
                 elif not self.skip_inferences:
-                    print("Extracting Entities-------------------------------------", filtered_triples)
-                    relationships = self.semantic_extractor.process_tokens(filtered_triples[:5])
+                    relationships = self.semantic_extractor.process_tokens(filtered_triples, fixed_entities=(len(self.sample_entities) >= 1))
                     relationships = self.semantictriplefilter.filter_triples(relationships)
-                    print("Relationships: {}".format(relationships))
                     if len(relationships) > 0:
                         if self.fixed_relationships and self.sample_relationships:
                             embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True)
                         elif self.sample_relationships:
-                            print("Only for sample_relationships")
                             embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True)
                         else:
                             embedding_triples = self.create_emb.generate_embeddings(relationships)
@@ -222,5 +214,4 @@ async def process_tokens(self, data: IngestedTokens):
                 else:
                     return filtered_triples, file
         except Exception as e:
-            print("Exception Caught: %s" % e)
             self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}")
3 changes: 0 additions & 3 deletions querent/core/transformers/fixed_entities_set_opensourcellm.py
@@ -149,8 +149,6 @@ async def process_tokens(self, data: IngestedTokens):
             if content:
                 if self.fixed_entities:
                     content = self.entity_context_extractor.find_entity_sentences(content)
-                # if self.fixed_relationships:
-                #     content = self.predicate_context_extractor.find_predicate_sentences(content)
                 tokens = self.ner_llm_instance._tokenize_and_chunk(content)
                 for tokenized_sentence, original_sentence, sentence_idx in tokens:
                     (entities, entity_pairs,) = self.ner_llm_instance.extract_entities_from_sentence(original_sentence, sentence_idx, [s[1] for s in tokens],self.isConfinedSearch, self.fixed_entities, self.sample_entities)
@@ -176,7 +174,6 @@ async def process_tokens(self, data: IngestedTokens):
                     if self.fixed_relationships and self.sample_relationships:
                         embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True)
                     elif self.sample_relationships:
-                        print("Only for sample_relationships")
                         embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True)
                     else:
                         embedding_triples = self.create_emb.generate_embeddings(relationships)
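Note: the same three-way branch over fixed_relationships and sample_relationships appears in all three engines touched by this commit. A standalone sketch of the dispatch; only the generate_embeddings keyword names come from the diff, the wrapper itself is illustrative:

# Hedged sketch of the embedding dispatch shared across the engines.
def embed_relationships(create_emb, relationships,
                        fixed_relationships, sample_relationships):
    # Fixed and sample relationships supplied: constrain the relationship
    # finder to the user-provided predicate set.
    if fixed_relationships and sample_relationships:
        return create_emb.generate_embeddings(
            relationships,
            relationship_finder=True,
            generate_embeddings_with_fixed_relationship=True,
        )
    # Sample relationships only: run the relationship finder unconstrained.
    if sample_relationships:
        return create_emb.generate_embeddings(relationships, relationship_finder=True)
    # Default: plain triple embeddings.
    return create_emb.generate_embeddings(relationships)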
@@ -58,7 +58,9 @@ def __init__(
             is_confined_search = config.is_confined_search,
             huggingface_token = config.huggingface_token,
             spacy_model_path = config.spacy_model_path,
-            nltk_path = config.nltk_path)
+            nltk_path = config.nltk_path,
+            fixed_relationships = config.fixed_relationships,
+            sample_relationships = config.sample_relationships)
         self.fixed_entities = config.fixed_entities
         self.is_confined_search = config.is_confined_search
         self.fixed_relationships = config.fixed_relationships
@@ -144,7 +146,7 @@ def remove_items_from_tuples(data: List[Tuple[str, str, str]]) -> List[Tuple[str

     async def process_triples(self, context, entity1, entity2, entity1_label, entity2_label):
         try:
-            if not self.user_context:
+            if not self.user_context and not self.fixed_entities:
                 identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format.
                 Context: {context}
@@ -163,7 +165,29 @@ async def process_triples(self, context, entity1, entity2, entity1_label, entity
                 """
                 messages_classify_entity = [
                     {"role": "user", "content": identify_entity_message},
-                    {"role": "user", "content": "Query: First, identify all geological entities in the provided context. Then, create relevant semantic triples (Subject, Predicate, Object) and also categorize the respective the Subject, Object types (e.g. location, person, event, material, process etc.) and Predicate type. Use the above output format to provide all the relevant semantic triples."},
+                    {"role": "user", "content": "Query : In the context of a semantic triple framework, first identify which entity is subject and which is the object along with their respective types. Also determine the predicate and predicate type."},
                 ]
+            elif not self.user_context and self.fixed_entities :
+                identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format.
+                Context: {context}
+                Entity 1: {entity1} and Entity 2: {entity2}
+                Entity 1_Type: {entity1_label} and Entity 2_Type: {entity2_label}
+                Output Format:
+                [
+                    {{
+                        'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.',
+                        'predicate': 'The relationship (predicate) between the subject and the object.',
+                        'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.',
+                        'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.',
+                        'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.',
+                        'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.'
+                    }},
+                ]
+                """
+                messages_classify_entity = [
+                    {"role": "user", "content": identify_entity_message},
+                    {"role": "user", "content": "Query : In the context of a semantic triple framework, first identify which entity is subject and which is the object and also validate and output their their respective types. Also determine the predicate and predicate type."},
+                ]
             elif self.user_context and self.fixed_entities :
                 identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format.
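Note: process_triples now dispatches on two flags, whether a user_context query override exists and whether entity types are fixed. A hedged sketch of the query selection; the branch structure follows the diff, while the helper name and abbreviated strings are illustrative:

from typing import Optional

# Hypothetical condensation of the prompt dispatch above.
def choose_query(user_context: Optional[str], fixed_entities: bool) -> str:
    if user_context:
        # Caller-supplied query text (see llm_config.user_context).
        return user_context
    if fixed_entities:
        # Entity types are already known, so the model only validates them.
        return ("In the context of a semantic triple framework, identify which "
                "entity is subject and which is object, validate their types, "
                "and determine the predicate and predicate type.")
    # No fixed types: the model must also infer subject and object types.
    return ("In the context of a semantic triple framework, identify which "
            "entity is subject and which is object along with their types, "
            "and determine the predicate and predicate type.")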
@@ -253,14 +277,20 @@ def generate_output_tuple(self,result, context_json):
         )

         return output_tuple
+
+    def extract_key(tup):
+        subject, json_string, obj = tup
+        data = json.loads(json_string.replace("\n", ""))
+        return (subject, data.get('predicate'), obj)

     async def process_tokens(self, data: IngestedTokens):
         try:
             if not GPTLLM.validate_ingested_tokens(data):
                 self.set_termination_event()
                 return
             relationships = []
-            result = await self.llm_instance.process_tokens(data)
+            unique_keys = set()
+            result = await self.llm_instance.process_tokens(data)
             if not result: return
             else:
                 filtered_triples, file = result
@@ ... @@
                     result = await self.process_triples(context, entity1_nn_chunk, entity2_nn_chunk, entity1_label, entity2_label)
                     if result:
                         output_tuple = self.generate_output_tuple(result, context_json)
-                        relationships.append(output_tuple)
+                        key = GPTLLM.extract_key(output_tuple)
+                        if key not in unique_keys:
+                            unique_keys.add(key)
+                            relationships.append(output_tuple)
                 if len(relationships) > 0:
                     if self.fixed_relationships and self.sample_relationships:
                         embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True)
                     elif self.sample_relationships:
-                        print("Only for sample_relationships")
                         embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True)
                     else:
                         embedding_triples = self.create_emb.generate_embeddings(relationships)
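Note: the new extract_key helper reduces each output tuple to a hashable (subject, predicate, object) key so duplicate triples are dropped before embedding. A self-contained usage sketch; the sample tuples are invented for illustration:

import json

def extract_key(tup):
    # Mirrors the helper added above: parse the JSON payload in the middle of
    # the tuple and reduce it to a (subject, predicate, object) key.
    subject, json_string, obj = tup
    data = json.loads(json_string.replace("\n", ""))
    return (subject, data.get('predicate'), obj)

# Invented sample tuples; real ones come from generate_output_tuple.
triples = [
    ("eocene", '{"predicate": "deposited in", "context": "..."}', "gulf of mexico"),
    ("eocene", '{"predicate": "deposited in", "context": "other"}', "gulf of mexico"),
]
unique_keys, relationships = set(), []
for t in triples:
    key = extract_key(t)
    if key not in unique_keys:  # keep only the first occurrence of each triple
        unique_keys.add(key)
        relationships.append(t)
assert len(relationships) == 1  # the duplicate triple is dropped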