Skip to content

Commit

Permalink
Implementing Fixed Relationship matching logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ngupta10 committed Apr 10, 2024
1 parent f15d3bb commit 0c862a3
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 12 deletions.
6 changes: 5 additions & 1 deletion querent/core/transformers/bert_ner_opensourcellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ def __init__(
raise ValueError("If specific predicates are provided, their types should also be provided.")
if self.fixed_relationships and self.sample_relationships:
self.predicate_context_extractor = FixedPredicateExtractor(fixed_predicates=self.fixed_relationships, predicate_types=self.sample_relationships,model = self.nlp_model)
self.predicate_json = self.predicate_context_extractor.construct_predicate_json(self.fixed_relationships, self.sample_relationships)
self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json)
elif self.sample_relationships:
self.predicate_context_extractor = FixedPredicateExtractor(predicate_types=self.sample_relationships,model = self.nlp_model)
self.predicate_json = self.predicate_context_extractor.construct_predicate_json(self.sample_relationships)
self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json)
else:
self.predicate_context_extractor = None
self.user_context = config.user_context
Expand Down Expand Up @@ -175,7 +179,7 @@ async def process_tokens(self, data: IngestedTokens):
return
if self.sample_entities:
doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
if doc_entity_pairs and any(doc_entity_pairs):
if any(doc_entity_pairs):
doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs)
pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs)
if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor:
Expand Down
44 changes: 43 additions & 1 deletion querent/kg/ner_helperfunctions/fixed_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import List
from nltk.corpus import wordnet as wn
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

"""
FixedPredicateExtractor is a class designed for extracting sentences containing specific predicates or predicate types from text. It utilizes spaCy for natural language processing and WordNet for synonym expansion.
Expand Down Expand Up @@ -133,4 +135,44 @@ def process_predicate_types(self, doc_predicates):

return filtered_predicates
except Exception as e:
raise Exception(f"Error processing predicate types: {e}")
raise Exception(f"Error processing predicate types: {e}")

def construct_predicate_json(self, relationships=None, relationship_types=None):
    """
    Build a JSON array describing the fixed relationships to embed.

    Each entry carries a display string ("predicate_value") plus the raw
    relationship / type fields so downstream similarity matching can recover
    them after embedding.

    NOTE(review): call sites invoke this as an instance method
    (self.predicate_context_extractor.construct_predicate_json(...)), so the
    original signature was missing `self`; the instance object was silently
    bound to `relationships`, which broke both call paths. `self` is added
    here to fix that.

    Args:
        relationships: Optional list of relationship names (e.g. "causes").
        relationship_types: Optional list of relationship types; must be
            parallel to `relationships` when both are given.

    Returns:
        A JSON string encoding a list of predicate dicts, or None when
        neither argument is provided (original contract preserved).

    Raises:
        Exception: If both lists are given but their lengths differ.
    """
    predicate_values = []
    if relationships and relationship_types:
        if len(relationships) != len(relationship_types):
            raise Exception("'relationships' and 'relationship_types' lists must have the same length.")
        for relationship, relationship_type in zip(relationships, relationship_types):
            predicate_values.append({
                "predicate_value": f"{relationship} ({relationship_type})",
                "relationship": relationship,
                "type": relationship_type,
            })
    elif relationship_types:
        for relationship_type in relationship_types:
            predicate_values.append({"predicate_value": relationship_type, "type": relationship_type})
    else:
        # Nothing to build from; keep the original "return None" behavior.
        return None
    return json.dumps(predicate_values)



def update_embedding_triples_with_similarity(predicate_json_emb, embedding_triples):
    """
    Relabel each triple's predicate with its most similar fixed predicate.

    For every triple that carries a predicate embedding, find the fixed
    predicate (from `predicate_json_emb`) with the highest cosine similarity
    and copy its "type" — and, unless it is "unlabelled", its relationship
    name — onto the triple.

    NOTE(review): if this lives inside FixedPredicateExtractor like its
    siblings it is also missing `self` — confirm against the call site.

    Args:
        predicate_json_emb: List of JSON strings, each with "predicate_emb"
            (a vector, or the sentinel "Not Implemented"), "relationship"
            and "type" keys.
        embedding_triples: List of JSON strings, each with a "predicate_emb"
            key (a vector, or "Not Implemented").

    Returns:
        The triples re-serialized as JSON strings, updated in place.
    """
    predicate_json_emb = [json.loads(item) for item in predicate_json_emb]
    embedding_triples = [json.loads(item) for item in embedding_triples]
    # Keep the metadata list and the embedding matrix ALIGNED: the original
    # code filtered "Not Implemented" entries out of the matrix but then used
    # argmax to index the unfiltered list, selecting the wrong predicate
    # whenever any entry lacked an embedding.
    candidates = [item for item in predicate_json_emb
                  if item["predicate_emb"] != "Not Implemented"]
    if candidates:
        predicate_emb_matrix = np.array([item["predicate_emb"] for item in candidates], dtype=float)
        # Pre-normalize rows once so each triple only needs one dot product.
        row_norms = np.linalg.norm(predicate_emb_matrix, axis=1)
        row_norms[row_norms == 0] = 1.0
        for triple in embedding_triples:
            if triple["predicate_emb"] == "Not Implemented":
                # No embedding to compare against; leave the triple untouched.
                continue
            current = np.asarray(triple["predicate_emb"], dtype=float).ravel()
            current_norm = np.linalg.norm(current) or 1.0
            # Cosine similarity computed directly with numpy; no sklearn needed.
            similarities = predicate_emb_matrix @ current / (row_norms * current_norm)
            best = candidates[int(np.argmax(similarities))]
            triple["predicate_type"] = best["type"]
            if best["relationship"].lower() != "unlabelled":
                triple["predicate"] = best["relationship"]
    return [json.dumps(item) for item in embedding_triples]
50 changes: 40 additions & 10 deletions querent/kg/rel_helperfunctions/embedding_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,13 @@ def get_embeddings(self, texts):
else:
payload = {"inputs": text}
embedding = self.query(payload)
if isinstance(self.embeddings,HuggingFaceEmbeddings) or isinstance(self.embeddings, HuggingFaceInferenceAPIEmbeddings) :
embedding = self.embeddings.embed_query(text)
embeddings.append(embedding)
else:
payload = {"inputs": text}
embedding = self.query(payload)
return embeddings
except Exception as e:
self.logger.error(f"Failed to generate embeddings: {e}")
raise Exception(f"Failed to generate embeddings: {e}")


def generate_embeddings(self, payload):
def generate_embeddings(self, payload, relationship_finder=False, generate_embeddings_with_fixed_relationship = False):
try:
triples = payload
processed_pairs = []
Expand All @@ -159,14 +153,23 @@ def generate_embeddings(self, payload):
predicate_type = data.get("predicate_type","Unlabeled").replace('"', '\\"')
subject_type = data.get("subject_type","Unlabeled").replace('"', '\\"')
object_type = data.get("object_type","Unlabeled").replace('"', '\\"')
context_embeddings = self.get_embeddings([context])[0]
context_embeddings = None
predicate_embedding = None
if not relationship_finder:
context_embeddings = self.get_embeddings([context])[0]
else:
if generate_embeddings_with_fixed_relationship:
predicate_embedding = self.get_embeddings([predicate + " ("+predicate_type+")"])[0]
else:
predicate_embedding = self.get_embeddings([predicate_type])[0]
essential_data = {
"context": context,
"context_embeddings" : context_embeddings,
"predicate_type": predicate_type,
"predicate" : predicate,
"subject_type": subject_type,
"object_type": object_type
"object_type": object_type,
"predicate_emb": predicate_embedding if predicate_embedding is not None else "Not Implemented"
}
updated_json_string = json.dumps(essential_data)
processed_pairs.append((entity, updated_json_string, related_entity))
Expand All @@ -178,4 +181,31 @@ def generate_embeddings(self, payload):
except Exception as e:
self.logger.error(f"Error in extracting embeddings: {e}")


def generate_relationship_embeddings(self, payload):
    """
    Embed each fixed-relationship predicate description.

    Args:
        payload: Iterable of JSON strings, each carrying a "predicate_value"
            to embed plus optional "relationship" and "type" fields.

    Returns:
        List of JSON strings, each the input record augmented with a
        "predicate_emb" vector; entries that fail to parse are logged and
        skipped. Returns None if embedding generation itself fails (the
        error is logged), matching generate_embeddings above.
    """
    try:
        processed_pairs = []
        for relation in payload:
            try:
                data = json.loads(relation)
                predicate_value = data.get("predicate_value", "").replace('"', '\\"')
                relationship = data.get("relationship", "unlabelled").replace('"', '\\"')
                # Supply a default here too: the original .get("type") with no
                # fallback raised AttributeError on .replace when the key was
                # missing, unlike every sibling field.
                relationship_type = data.get("type", "unlabelled").replace('"', '\\"')
                predicate_embedding = self.get_embeddings([predicate_value])[0]
                essential_data = {
                    "predicate_value": predicate_value,
                    "predicate_emb": predicate_embedding,
                    "relationship": relationship,
                    "type": relationship_type,
                }
                processed_pairs.append(json.dumps(essential_data))
            except json.JSONDecodeError as e:
                self.logger.debug(f"JSON parsing error while generating embeddings for fixed relationships: {e} in string.")
        return processed_pairs
    except Exception as e:
        self.logger.error(f"Error in extracting embeddings for fixed relationships: {e}")

0 comments on commit 0c862a3

Please sign in to comment.