
Global and Local structure #16

Open
CaptainCuong opened this issue Oct 30, 2022 · 3 comments

Comments

@CaptainCuong

CaptainCuong commented Oct 30, 2022

Hi Minkai Xu,

What is the motivation for designing two separate architectures to learn the local and global structures?
The loss is also split into a global term and a local term: `(node_eq_global - target_pos_global)**2 + (node_eq_local - target_pos_local)**2`.
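As a rough illustration of how the two terms combine, here is a minimal pure-Python sketch. It assumes `node_eq_*` and `target_pos_*` are per-node 3D vectors and that each term is reduced with a plain sum of squared errors; the real code works on tensors and its exact reduction is not shown in this thread.

```python
def squared_error(pred, target):
    """Sum of squared differences over all nodes and coordinates."""
    return sum(
        (p - t) ** 2
        for p_vec, t_vec in zip(pred, target)
        for p, t in zip(p_vec, t_vec)
    )


def combined_loss(node_eq_global, target_pos_global, node_eq_local, target_pos_local):
    # Total objective = global term + local term (reduction is an assumption).
    loss_global = squared_error(node_eq_global, target_pos_global)
    loss_local = squared_error(node_eq_local, target_pos_local)
    return loss_global + loss_local, loss_global, loss_local


# Toy example with two atoms (hypothetical numbers).
total, g, l = combined_loss(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]],
    [[0.5, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
)
```

The return order `(total, global, local)` simply mirrors the `loss, loss_global, loss_local` unpacking used in `train.py`.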

@MinkaiXu
Owner

MinkaiXu commented Nov 1, 2022

Hi,

The two encoders are meant to encode, respectively, the real bonds and the virtual bonds between atoms within a distance threshold. In practice, this separation helps performance.

@mulangonando

Hi Minkai,

I am trying to reproduce the results of your GeoDiff paper.
The code runs and trains fine, and I am now generating samples.

To understand the code better, I am looking at the losses. I see three losses (`loss`, `loss_global`, and `loss_local`), each computed with an objective close to the RMSD.

What is the difference between the local and global losses? As far as I can tell, both are based on the same arguments:
```python
edge_index=edge_index,
edge_length=edge_length,
edge_attr=edge_attr_global
```
Both encoders are defined by these three attributes. The only difference I can see is that the local branch is restricted with `local_edge_mask`:

```python
# Local
node_attr_local = self.encoder_local(
    z=atom_type,
    edge_index=edge_index[:, local_edge_mask],
    edge_attr=edge_attr_local[local_edge_mask],
)
```

Following `local_edge_mask`, I find that it is based on whether an edge's type is local or not:

```python
local_edge_mask = is_local_edge(edge_type)  # (E, )

def is_local_edge(edge_type):
    return edge_type > 0
```
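To see what `edge_index[:, local_edge_mask]` selects, here is a small pure-Python sketch that imitates boolean column indexing with plain lists. The edge types below are made-up values; only the rule `edge_type > 0` comes from the code above. Per the arguments quoted earlier, the global encoder receives the unmasked `edge_index`.

```python
def is_local_edge(edge_type):
    # Same rule as the snippet above: positive types are "local".
    return [t > 0 for t in edge_type]


def select_edges(edge_index, mask):
    """Keep the columns of a 2 x E edge index where mask is True,
    mimicking edge_index[:, mask] on a tensor."""
    rows, cols = edge_index
    kept = [(r, c) for r, c, m in zip(rows, cols, mask) if m]
    return [[r for r, _ in kept], [c for _, c in kept]]


# Hypothetical graph: edges 0-1 and 1-2 are bonds (types 1 and 2),
# edge 0-2 is a distance-based edge (type 0).
edge_index = [[0, 1, 0], [1, 2, 2]]
edge_type = [1, 2, 0]

local_edge_mask = is_local_edge(edge_type)
local_edges = select_edges(edge_index, local_edge_mask)
```

Here `local_edge_mask` is `[True, True, False]`, so the local encoder would only see the two bond edges.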

This is where I get confused: which edge types are > 0? Could you explain `edge_type` and `local_edge_mask`, i.e., what makes an edge local or global, what the different edge types are, and which numbers they are assigned in the code?

In the `train.py` script:
```python
batch = next(train_iterator).to(args.device)
loss, loss_global, loss_local = model.get_loss(
    atom_type=batch.atom_type,
    pos=batch.pos,
    bond_index=batch.edge_index,
    bond_type=batch.edge_type,
    batch=batch.batch,
    num_nodes_per_graph=batch.num_nodes_per_graph,
    num_graphs=batch.num_graphs,
    anneal_power=config.train.anneal_power,
    return_unreduced_loss=True
)
```

I see that the batch is loaded onto CUDA from a `torch_geometric.data` Dataset.
Where do properties of the batch object such as `pos` and `atom_type` come from, and what do they represent? Could you explain this as well?
It also looks like the local/global split somehow depends on the `pos` parameter.
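On the batch attributes: PyTorch Geometric merges several graphs into one large disconnected graph. It concatenates node-level attributes, shifts each graph's `edge_index` by the number of nodes that came before it, and records a `batch` vector mapping each node to its graph. The pure-Python sketch below imitates that idea; it is a conceptual illustration, not PyG's implementation. The dictionary layout is an assumption, and the keys only mirror the names used in `train.py`.

```python
def collate(graphs):
    """Merge a list of graph dicts into one batched dict."""
    atom_type, pos, batch = [], [], []
    rows, cols = [], []
    num_nodes_per_graph = []
    offset = 0
    for g_id, g in enumerate(graphs):
        n = len(g["atom_type"])
        atom_type.extend(g["atom_type"])
        pos.extend(g["pos"])
        batch.extend([g_id] * n)  # which graph each node belongs to
        # Shift node indices so they point into the merged node list.
        rows.extend(i + offset for i in g["edge_index"][0])
        cols.extend(j + offset for j in g["edge_index"][1])
        num_nodes_per_graph.append(n)
        offset += n
    return {
        "atom_type": atom_type,
        "pos": pos,
        "edge_index": [rows, cols],
        "batch": batch,
        "num_nodes_per_graph": num_nodes_per_graph,
        "num_graphs": len(graphs),
    }


# Two hypothetical molecules: a two-atom graph and a three-atom chain.
g1 = {"atom_type": [6, 8],
      "pos": [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]],
      "edge_index": [[0, 1], [1, 0]]}
g2 = {"atom_type": [6, 6, 1],
      "pos": [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.5, 0.0, 0.0]],
      "edge_index": [[0, 1], [1, 2]]}
merged = collate([g1, g2])
```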

Your help is highly appreciated.

@MinkaiXu
Owner

MinkaiXu commented Apr 5, 2023

Hi @mulangonando,

Thanks for your interest!

`pos` and `atom_type` are the atom positions and atom types, respectively. Local edges are the bonds that exist in the molecular graph, while global edges are those added between two atoms that are close enough in space. Global edges are labeled as type 0.
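A minimal sketch of this labeling, assuming bonds carry positive type ids and every non-bonded atom pair closer than a cutoff becomes a global edge of type 0. The brute-force pair search, the undirected `(i, j)` keys, and the toy cutoff of 3.5 (the repository's function defaults to `cutoff=10.0`) are simplifications for illustration; the repository builds these edges with tensor operations in `extend_graph_order_radius`.

```python
import math


def label_edges(pos, bonds, cutoff):
    """Return {(i, j): edge_type} for i < j.

    `bonds` maps (i, j) pairs to positive bond types (local edges);
    any other pair closer than `cutoff` gets type 0 (global edge).
    """
    edges = {(min(i, j), max(i, j)): t for (i, j), t in bonds.items()}
    n = len(pos)
    for i in range(n):
        for j in range(i + 1, n):
            if (i, j) not in edges and math.dist(pos[i], pos[j]) < cutoff:
                edges[(i, j)] = 0  # distance-based (global) edge
    return edges


pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0], [10.0, 0.0, 0.0]]
bonds = {(0, 1): 1, (1, 2): 1}  # hypothetical single bonds
edges = label_edges(pos, bonds, cutoff=3.5)
local_edges = sorted(e for e, t in edges.items() if t > 0)
global_edges = sorted(e for e, t in edges.items() if t == 0)
```

Because the type-0 edges are chosen by distance, they change whenever `pos` changes, which is why the local/global split appears to depend on `pos`.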

I recommend having a closer look at the function in

GeoDiff/models/common.py

Lines 231 to 254 in c6f26dc

```python
def extend_graph_order_radius(num_nodes, pos, edge_index, edge_type, batch, order=3, cutoff=10.0,
                              extend_order=True, extend_radius=True, is_sidechain=None):
    # Add higher-order edges derived from the bond graph (topology only, no `pos`).
    if extend_order:
        edge_index, edge_type = _extend_graph_order(
            num_nodes=num_nodes,
            edge_index=edge_index,
            edge_type=edge_type, order=order
        )
        # edge_index_order = edge_index
        # edge_type_order = edge_type
    # Add edges between atoms closer than `cutoff` (depends on `pos`).
    if extend_radius:
        edge_index, edge_type = _extend_to_radius_graph(
            pos=pos,
            edge_index=edge_index,
            edge_type=edge_type,
            cutoff=cutoff,
            batch=batch,
            is_sidechain=is_sidechain
        )
    return edge_index, edge_type
```
This is where we calculate the local and global edges.
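The `extend_order` step is not described further in this thread. As a hedged illustration of what extending a bond graph "by order" usually means, here is a breadth-first-search sketch that finds atom pairs connected within `order` hops. It reports only the hop count; how `_extend_graph_order` maps hops to an `edge_type` is not reproduced here. Unlike the radius step, this uses only the bond topology, not `pos`.

```python
from collections import deque


def within_order(num_nodes, bonds, order=3):
    """Return {(i, j): hops} for i < j whose shortest bond path is
    between 2 and `order` hops (direct bonds, hops == 1, are excluded)."""
    adj = [[] for _ in range(num_nodes)]
    for i, j in bonds:
        adj[i].append(j)
        adj[j].append(i)
    result = {}
    for src in range(num_nodes):
        dist = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            if dist[u] == order:
                continue  # do not expand beyond `order` hops
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        for dst, hops in dist.items():
            if src < dst and hops >= 2:
                result[(src, dst)] = hops
    return result


# Hypothetical 5-atom chain 0-1-2-3-4.
chain = [(0, 1), (1, 2), (2, 3), (3, 4)]
extra = within_order(5, chain, order=3)
```

With `order=3`, the end atoms 0 and 4 (four hops apart) are not connected.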
