Skip to content

Commit

Permalink
added prints and traverser logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ngupta10 committed Apr 25, 2024
1 parent 1576a8b commit 3220b61
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,7 @@ lib/vis-9.1.2/vis-network.css
lib/vis-9.1.2/vis-network.min.js
tests/data/llm/cleaned_graph_event (copy).csv
tests/data/llm/cleaned_graph_event1.csv
graph.png
my_subgraph_data.csv
subgraph_output_2.csv
subgraph_output.csv
10 changes: 9 additions & 1 deletion querent/core/transformers/bert_ner_opensourcellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def process_tokens(self, data: IngestedTokens):
content = clean_text
file = data.get_file_path()
if content:
print("BERT Content -----------------------", content)
if self.fixed_entities:
content = self.entity_context_extractor.find_entity_sentences(content)
tokens = self.ner_llm_instance._tokenize_and_chunk(content)
Expand All @@ -158,9 +159,11 @@ async def process_tokens(self, data: IngestedTokens):
number_sentences = number_sentences + 1
else:
return
print("Doc Entity Pairss-------------------", doc_entity_pairs)
if self.sample_entities:
doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs)
if any(doc_entity_pairs):
print("Found odc entity pairssssssssssssssssssssssssssssssss")
doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs)
pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs)
if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor:
Expand All @@ -184,8 +187,11 @@ async def process_tokens(self, data: IngestedTokens):
if not filtered_triples:
return
elif not self.skip_inferences:
relationships = self.semantic_extractor.process_tokens(filtered_triples, fixed_entities=(len(self.sample_entities) >= 1))
print("Going to run BERT")
relationships = self.semantic_extractor.process_tokens(filtered_triples[:5], fixed_entities=(len(self.sample_entities) >= 1))
print ("Found these relationshipssssssss ----", relationships)
relationships = self.semantictriplefilter.filter_triples(relationships)
print ("Found these relationshipssssssss ----", relationships)
if len(relationships) > 0:
if self.fixed_relationships and self.sample_relationships:
embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True)
Expand All @@ -194,6 +200,7 @@ async def process_tokens(self, data: IngestedTokens):
else:
embedding_triples = self.create_emb.generate_embeddings(relationships)
if self.sample_relationships:
print("Inside BERT going to compute scortessssss")
embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity(self.predicate_json_emb, embedding_triples)
for triple in embedding_triples:
if not self.termination_event.is_set():
Expand All @@ -214,4 +221,5 @@ async def process_tokens(self, data: IngestedTokens):
else:
return
except Exception as e:
print("Exception in BERT: -----------------------------", e)
self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}")
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ async def process_triples(self, context, entity1, entity2, entity1_label, entity
{"role": "user", "content": identify_entity_message},
{"role": "user", "content": self.user_context},
]
print("GPT LLM prompt message -------------------------", messages_classify_entity)
identify_predicate_response = self.generate_response(
messages_classify_entity,
"predicate_info"
Expand Down Expand Up @@ -292,11 +293,13 @@ async def process_tokens(self, data: IngestedTokens):
doc_source = data.doc_source
relationships = []
unique_keys = set()
print("Inside GPT-----------------------")
result = await self.llm_instance.process_tokens(data)
if not result: return
else:
filtered_triples, file = result
modified_data = GPTLLM.remove_items_from_tuples(filtered_triples)
print("Data in GPT------------------------", modified_data[:1])
for entity1, context_json, entity2 in modified_data:
context_data = json.loads(context_json)
context = context_data.get("context", "")
Expand All @@ -313,12 +316,15 @@ async def process_tokens(self, data: IngestedTokens):
relationships.append(output_tuple)
if len(relationships) > 0:
if self.fixed_relationships and self.sample_relationships:
print("Both are settttttttttttttttttttt-----")
embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True)
elif self.sample_relationships:
print("Only Sample Relationships are settttttttttttttttttttttttt-----")
embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True)
else:
embedding_triples = self.create_emb.generate_embeddings(relationships)
if self.sample_relationships:
print("Going to compute scores------------------------------")
embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity(self.predicate_json_emb, embedding_triples)
for triple in embedding_triples:
if not self.termination_event.is_set():
Expand All @@ -335,6 +341,7 @@ async def process_tokens(self, data: IngestedTokens):
else:
return
except Exception as e:
print("Exception in GPT-----------------------", e)
self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to extract predicates using GPT. {e}")
raise Exception(f"An error occurred while extracting predicates using GPT: {e}")

Expand Down
4 changes: 4 additions & 0 deletions querent/kg/ner_helperfunctions/fixed_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def update_embedding_triples_with_similarity(self, predicate_json_emb, embedding
predicate_emb_list = [item["predicate_emb"] for item in predicate_json_emb if item["predicate_emb"] != "Not Implemented"]
predicate_emb_matrix = np.array(predicate_emb_list)
updated_embedding_triples = []
print("Embedding Triples withSimilarity----", embedding_triples)
print("Predicate Matrixxxxxxxxx", predicate_emb_matrix)
for triple in embedding_triples:
entity, triple_json, study_field = triple
triple_data = json.loads(triple_json)
Expand All @@ -173,12 +175,14 @@ def update_embedding_triples_with_similarity(self, predicate_json_emb, embedding
similarities = cosine_similarity(current_predicate_emb, predicate_emb_matrix)
max_similarity_index = np.argmax(similarities)
most_similar_predicate_details = predicate_json_emb[max_similarity_index]
print("Max similarity index -------", similarities)
if similarities[0][max_similarity_index] > 0.5:
triple_data["predicate_type"] = most_similar_predicate_details["type"]
if most_similar_predicate_details["relationship"].lower() != "unlabelled":
triple_data["predicate"] = most_similar_predicate_details["relationship"]
updated_triple_json = json.dumps(triple_data)
updated_embedding_triples.append((entity, updated_triple_json, study_field))
print("updated_embedding_triples------------", updated_embedding_triples)
return updated_embedding_triples
except Exception as e:
raise Exception(f"Error processing predicate types: {e}")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

setup(
name="querent",
version="3.0.3",
version="3.0.4",
author="Querent AI",
description="The Asynchronous Data Dynamo and Graph Neural Network Catalyst",
long_description=long_description,
Expand Down
Empty file added tests/traverser/__init__.py
Empty file.
108 changes: 108 additions & 0 deletions tests/traverser/kge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import pandas as pd
from torch.nn.functional import cosine_similarity
import random
import numpy as np

# Seed every RNG we rely on (Python, NumPy, Torch) so runs are reproducible.
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(_SEED)
# Force cuDNN into deterministic mode (disables the autotuner as well).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class TextEnhancedKGE(nn.Module):
    """Knowledge-graph embedding model whose relation representation is
    enriched with a BERT sentence embedding of the supporting text.

    A (head, relation, tail) triple plus its evidence sentence is scored by
    concatenating [head_emb, relation_emb + projected_sentence_emb, tail_emb]
    and passing it through a linear scoring layer.

    Args:
        entity_dim: dimensionality of entity embeddings.
        relation_dim: dimensionality of relation embeddings; the 768-dim
            BERT [CLS] vector is projected down to this size.
        entity_to_idx: mapping entity name -> integer index.
        relation_to_idx: mapping relation type -> integer index.
        bert_model_name: Hugging Face model id used for both tokenizer
            and encoder.
    """

    def __init__(self, entity_dim, relation_dim, entity_to_idx, relation_to_idx, bert_model_name='bert-base-uncased'):
        super().__init__()
        self.entity_embeddings = nn.Embedding(len(entity_to_idx), entity_dim)
        self.relation_embeddings = nn.Embedding(len(relation_to_idx), relation_dim)
        # Projects the 768-dim BERT [CLS] vector into relation space.
        self.sentence_projection = nn.Linear(768, relation_dim)
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.score_layer = nn.Linear(entity_dim * 2 + relation_dim, 1)
        # Used by `query` to lift triple vectors back into BERT's 768-dim space.
        self.combined_projection = nn.Linear(entity_dim * 2 + relation_dim, 768)

        nn.init.xavier_uniform_(self.entity_embeddings.weight)
        nn.init.xavier_uniform_(self.relation_embeddings.weight)
        nn.init.xavier_uniform_(self.sentence_projection.weight)
        nn.init.xavier_uniform_(self.combined_projection.weight)

    def sentence_to_embedding(self, sentences):
        """Encode `sentences` with BERT and return their [CLS] vectors.

        Returns a tensor of shape (batch, 768). The batch dimension is kept
        even for a single sentence: the previous trailing `.squeeze()`
        collapsed a batch of one to a 1-D (768,) tensor, which broke the
        dim=1 concatenation in `forward` for single-row inputs.
        NOTE(review): gradients flow through BERT here; wrap callers in
        `torch.no_grad()` if the encoder should stay frozen — confirm intent.
        """
        inputs = self.bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=500)
        outputs = self.bert_model(**inputs)
        # [CLS] token (position 0) of the last hidden state: (batch, 768).
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, heads, relations, tails, sentences):
        """Score a batch of triples with their evidence sentences.

        Args:
            heads, relations, tails: LongTensors of indices, shape (batch,).
            sentences: list of `batch` evidence strings.

        Returns:
            Tensor of raw scores, shape (batch, 1).
        """
        head_embeddings = self.entity_embeddings(heads)
        relation_embeddings = self.relation_embeddings(relations)
        tail_embeddings = self.entity_embeddings(tails)

        sentence_embeddings = self.sentence_to_embedding(sentences)
        projected_sentences = self.sentence_projection(sentence_embeddings)

        score = self.calculate_score(head_embeddings, relation_embeddings, tail_embeddings, projected_sentences)
        return score

    def calculate_score(self, head_embeddings, relation_embeddings, tail_embeddings, sentence_embeddings):
        """Linear score over [head, relation + sentence, tail]; shape (batch, 1)."""
        combined_embeddings = torch.cat([head_embeddings, relation_embeddings + sentence_embeddings, tail_embeddings], dim=1)
        return self.score_layer(combined_embeddings)

    def query(self, query_text, heads, relations, tails, sentences):
        """Rank all triples by cosine similarity to a free-text query.

        Returns a similarity tensor aligned with the input triples; higher
        means the triple's combined representation is closer to the query.
        """
        # Already (1, 768) — batch dimension is kept by sentence_to_embedding.
        query_embedding = self.sentence_to_embedding([query_text])
        sentence_embeddings = self.sentence_to_embedding(sentences)
        projected_sentences = self.sentence_projection(sentence_embeddings)

        head_embeddings = self.entity_embeddings(heads)
        relation_embeddings = self.relation_embeddings(relations)
        tail_embeddings = self.entity_embeddings(tails)

        relation_plus_sentence = relation_embeddings + projected_sentences
        combined_embeddings = torch.cat([head_embeddings, relation_plus_sentence, tail_embeddings], dim=1)

        # Lift triple vectors to 768 dims when they don't already match BERT's.
        if combined_embeddings.shape[-1] != query_embedding.shape[-1]:
            combined_embeddings = self.combined_projection(combined_embeddings)

        similarities = cosine_similarity(query_embedding, combined_embeddings, dim=-1)
        return similarities

# --- Load the exported subgraph triples -------------------------------------
data = pd.read_csv('my_subgraph_data.csv')
print(data.head())
print(data.columns)

# --- Build integer indices for entities and relation types ------------------
unique_entities = pd.concat([data['Node Start'], data['Node End']]).unique()
entity_to_idx = dict(zip(unique_entities, range(len(unique_entities))))
relation_to_idx = {rel: i for i, rel in enumerate(data['Relationship Type'].unique())}
print('Creating entity_to_idx and relation_to_idx', entity_to_idx)
print('2nd indx', relation_to_idx)

# --- Initialize the model ---------------------------------------------------
model = TextEnhancedKGE(
    entity_dim=100,
    relation_dim=100,
    entity_to_idx=entity_to_idx,
    relation_to_idx=relation_to_idx
)

# --- Convert the dataframe columns into model inputs ------------------------
heads = torch.LongTensor(data['Node Start'].map(entity_to_idx).values)
relations = torch.LongTensor(data['Relationship Type'].map(relation_to_idx).values)
tails = torch.LongTensor(data['Node End'].map(entity_to_idx).values)
sentences = data['Sentence'].tolist()

# --- Score every triple -----------------------------------------------------
scores = model(heads, relations, tails, sentences)
print(scores)

# --- Answer a free-text query against the graph -----------------------------
query_text = "How does hydraulic fracturing enhance porosity?"
similarities = model.query(query_text, heads, relations, tails, sentences)
top_matches = similarities.topk(10)  # Get the top 10 matches as per revised requirement
print(top_matches)
Loading

0 comments on commit 3220b61

Please sign in to comment.