The code has been restructured #82

Open · wants to merge 1 commit into main
154 changes: 82 additions & 72 deletions text-graph-grounding/graph_transformer.py
@@ -1,59 +1,55 @@
from statistics import mean
import torch as t
from torch import nn
import torch.nn.functional as F
import math

init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_  # in-place initializer; the non-underscore alias is deprecated


def PositionalEncoding(q_len, d_model, normalize=True):
    """Sinusoidal positional-encoding table of shape (q_len, d_model)."""
    pe = t.zeros(q_len, d_model)
    position = t.arange(0, q_len).unsqueeze(1)
    div_term = t.exp(t.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = t.sin(position * div_term)
    pe[:, 1::2] = t.cos(position * div_term)

    if normalize:
        pe = (pe - pe.mean()) / (pe.std() * 10)

    return pe

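A minimal usage sketch for the helper above (the sizes are arbitrary; d_model is assumed to be even, as the sin/cos interleaving requires):

# Illustration only, not part of the module.
pe = PositionalEncoding(q_len=10, d_model=8)
print(pe.shape)          # torch.Size([10, 8])
print(float(pe.std()))   # ~0.1 after the normalization above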

def pos_encoding(pe_type, learn_pe, nvar, d_model):
    """Build the positional-encoding parameter for the requested type."""
    if pe_type is None:
        # pe_type=None with learn_pe=False can be used to measure the impact of positional encoding.
        W_pos = t.empty((nvar, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe_type in ["zero", "zeros"]:
        W_pos = t.empty((nvar, d_model if pe_type == "zeros" else 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe_type in ["normal", "gauss"]:
        W_pos = t.zeros((nvar, 1))
        nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe_type == "uniform":
        W_pos = t.zeros((nvar, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe_type == "sincos":
        W_pos = PositionalEncoding(nvar, d_model)
    else:
        raise ValueError(
            f"{pe_type} is not a valid positional encoding type. "
            "Available types: 'gauss'=='normal', 'zeros', 'zero', 'uniform', 'sincos', None."
        )

    return nn.Parameter(W_pos, requires_grad=learn_pe)

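A short sketch of how the restructured pos_encoding behaves for the main pe_type values (sizes are arbitrary, for illustration only):

# Learnable encoding, one row per node, full model width.
W_zeros = pos_encoding("zeros", learn_pe=True, nvar=100, d_model=64)
print(W_zeros.shape, W_zeros.requires_grad)   # torch.Size([100, 64]) True

# Fixed sinusoidal table; requires_grad follows learn_pe.
W_sin = pos_encoding("sincos", learn_pe=False, nvar=100, d_model=64)
print(W_sin.shape, W_sin.requires_grad)       # torch.Size([100, 64]) False

# Anything else raises a ValueError listing the supported types.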

class GraphTransformer(nn.Module):
    def __init__(self, args):
        super(GraphTransformer, self).__init__()

        self.gtLayers = nn.Sequential(*[GTLayer(args) for _ in range(args.gt_layers)])
        self.W_pos = pos_encoding("zeros", True, args.num_nodes, args.att_d_model)

        self.W_P = nn.Linear(args.gnn_input, args.att_d_model)
@@ -62,77 +58,91 @@ def __init__(self, args):
        self.args = args

    def forward(self, g):
        # g.x: num_nodes * gnn_input node features
        x = g.x
        x, self.W_P.weight, self.W_P.bias, self.W_pos = move_to_same_device(
            [x, self.W_P.weight, self.W_P.bias, self.W_pos]
        )

        z = self.W_P(x)
        embeds = self.dropout(z + self.W_pos) if self.args.if_pos else self.dropout(z)

        for gt in self.gtLayers:
            embeds = gt(g, embeds)  # num_nodes * att_d_model

        embeds, self.inverW_P.weight, self.inverW_P.bias = move_to_same_device(
            [embeds, self.inverW_P.weight, self.inverW_P.bias]
        )
        return self.inverW_P(embeds)

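A hedged end-to-end sketch of how this module could be driven once the full file is loaded (including the GTLayer class below). The field names mirror the args attributes read above; the dropout rate field and the self.dropout / self.inverW_P layers are assumptions about the elided part of __init__ (an nn.Dropout and a Linear mapping att_d_model back to gnn_input), and the graph object only needs .x and .edge_index:

# Illustration only; the hyperparameter values and the toy graph are made up.
from types import SimpleNamespace

toy_args = SimpleNamespace(
    gt_layers=2, num_nodes=5, att_d_model=16, gnn_input=32,
    if_pos=True, head=4, att_norm=True, dropout=0.1,   # dropout field is an assumption
)
toy_graph = SimpleNamespace(
    x=t.randn(5, 32),                                         # num_nodes * gnn_input
    edge_index=t.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),  # 2 * num_edges
)

model = GraphTransformer(toy_args)
out = model(toy_graph)
print(out.shape)   # torch.Size([5, 32]) if inverW_P maps att_d_model back to gnn_input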

def move_to_same_device(tensors):
    """Move every tensor in the list to the device of the first one."""
    return [var.to(tensors[0].device) for var in tensors]


class GTLayer(nn.Module):
    def __init__(self, args):
        super(GTLayer, self).__init__()

        # Parameter initialization
        self.qTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model)))
        self.kTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model)))
        self.vTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model)))

        if args.att_norm:
            self.norm = nn.LayerNorm(args.att_d_model)

        self.args = args

    def forward(self, g, embeds):
        # embeds: num_nodes * att_d_model node embeddings; g.edge_index: 2 * num_edges
        rows, cols = g.edge_index
        nvar = embeds.shape[0]

        # Gather the endpoint embeddings of every edge.
        rowEmbeds = embeds[rows]
        colEmbeds = embeds[cols]
        evar = rowEmbeds.shape[0]

        rowEmbeds, self.qTrans, self.kTrans, self.vTrans = move_to_same_device(
            [rowEmbeds, self.qTrans, self.kTrans, self.vTrans]
        )

        # Per-edge query/key/value representations, split into heads.
        head_dim = self.args.att_d_model // self.args.head
        qEmbeds = (rowEmbeds @ self.qTrans).view(evar, self.args.head, head_dim)
        kEmbeds = (colEmbeds @ self.kTrans).view(evar, self.args.head, head_dim)
        vEmbeds = (colEmbeds @ self.vTrans).view(evar, self.args.head, head_dim)

        # Per-edge attention scores, clamped for numerical stability.
        att_scores = t.einsum("ehd,ehd->eh", qEmbeds, kEmbeds).clamp(-10.0, 10.0)
        expAttScores = t.exp(att_scores)

        # Softmax over the edges that share a target node (segment softmax via index_add_).
        rows = rows.to(expAttScores.device)
        attNorms = t.zeros(nvar, self.args.head, device=expAttScores.device)
        attNorms = attNorms.index_add_(0, rows, expAttScores)[rows]
        att_weights = expAttScores / (attNorms + 1e-8)

        # Weight the values, then scatter-sum the messages onto the target nodes.
        msgEmbeds = t.einsum("eh,ehd->ehd", att_weights, vEmbeds).reshape(evar, self.args.att_d_model)
        resEmbeds = t.zeros(nvar, self.args.att_d_model, device=msgEmbeds.device)
        resEmbeds = resEmbeds.index_add_(0, rows, msgEmbeds)

        # Residual connection and optional layer normalization.
        resEmbeds = resEmbeds + embeds
        if hasattr(self, "norm"):
            resEmbeds, self.norm.weight, self.norm.bias = move_to_same_device(
                [resEmbeds, self.norm.weight, self.norm.bias]
            )
            resEmbeds = self.norm(resEmbeds)

        return resEmbeds
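
The attention normalization in GTLayer.forward is a segment softmax: exponentiated per-edge scores are summed onto their target node with index_add_, and each edge is divided by its node's total. A small self-contained sketch of that trick with toy numbers, independent of the classes above:

import torch as t

# 4 edges pointing at 2 target nodes, a single attention head.
scores = t.tensor([[0.5], [1.5], [0.2], [0.3]])   # e * h logits
targets = t.tensor([0, 0, 1, 1])                  # target node of each edge

exp_scores = t.exp(scores)
totals = t.zeros(2, 1).index_add_(0, targets, exp_scores)   # per-node sums
weights = exp_scores / (totals[targets] + 1e-8)

print(weights.squeeze())                                     # per-edge weights
print(t.zeros(2).index_add_(0, targets, weights.squeeze()))  # ~1.0 per target node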