Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

用unimol+模型 inference出的pos_pred取同一个id的均值吗? #269

Open
chenkk717 opened this issue Sep 10, 2024 · 8 comments
Open

Comments

@chenkk717
Copy link

在make_pcq_test_dev_submission.py文件中gap_pred是取同一个id对应值的平均,大多数id是8个值的平均。然而在预测出的pos_pred(原子坐标信息中),为什么相同id得到的pos_pred形状会不一样?也就是说同一个分子的原子数量对不上?如下是某个id预测出的pos_pred
test
形状为Shape mismatch for id group: [(19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (15, 3)],第8个预测出来的pos_pred只有15个原子对应的坐标信息,而其它为19个原子。
想请问这是什么情况?以及想得到某个id的最终pos_pred值,是采用什么方法?除去有异常原子数量的预测值,然后其余取均值吗?

@chenkk717
Copy link
Author

同时想请问一下原paper(Data-driven quantum chemical property prediction leveraging 3D conformationswith Uni-Mol+)里Fig. 2
image
中predicted conformations的sdf文件是如何得来的?因为模型预测后只有原子的坐标信息,是根据smiles使用RDKit生成mol对象然后再把pos_pred代入吗?
image

@ShuqiLu
Copy link
Collaborator

ShuqiLu commented Sep 10, 2024

您好,请问第一个问题的id是什么,我可以先复现一下。然后也确认下这个id对应的构象是8个吗,因为不是所有的id都有8个构象;第二个问题确实是的,是把mol对象的坐标换成预测的坐标来生成的

@chenkk717
Copy link
Author

①第一个问题里的id在图片的上方:3379091,我是运行的test-dev split这个测试集,实际上不止这一个id有这个情况,截取部分预测原子数量不一致的id:
image
image
②想请问最终的final pos_pred是取同一个id的所有预测pos_pred的均值吗?还是按照某种规则挑选其一?
③smiles中如果含有H原子,那么在用预测的坐标替换掉由smiles结构生成的mol对象坐标时,氢原子的坐标信息应该怎么处理呢?因为在get_3d_lmdb.py对数据集的预处理文件中,我发现在def rdkit_mmff(mol)函数中return rdkit_remove_hs(mol)也就是input_pos或者label_pos都是不含有H原子信息的。在得到sdf文件时,是在smiles生成mol对象后再进行了H原子的移除吗?然后再代入预测坐标值

@ShuqiLu
Copy link
Collaborator

ShuqiLu commented Sep 11, 2024

1:不确定这里是怎么输出的每个id对应的构象坐标shape,如果可以的话可以提供一下相应的代码。不过不是所有id对应的构象都是8个,如果按照给出的图里全都是按照8个来切分,可能混淆了不同id的分子,所以原子数不一样。
2. 这篇论文里我们最终是为了获得预测的homo lumo gap,所以没有返回最终的预测坐标,在得到不同构象下的pos_pred之后没有再做处理了。
3. rdkit_remove_hs(mol)返回的是一个不带H的mol对象,我们show case的时候是基于这个mol对象进行的坐标替换再展示画图,过程中都没有H。

@chenkk717
Copy link
Author

chenkk717 commented Sep 11, 2024

1.发现相同id对应的预测构象坐标shape不同后,我检查了模型预测后得到的直接输出test-dev_0.pkl,以防是我输出shape的代码有误,check预测的坐标值(以id3379091为例子)结果如下:
image
其中第8个构象的原子数确实为15,而其它为19,和shape的输出[(19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (15, 3)]吻合,详细的数据见
id3379091_pos_pred.txt
这部分的代码如下:

import numpy as np
import torch
import pickle
import glob
import pandas as pd

input_folder = "results"
subset = "test-dev"
split = torch.load("./scripts/pcqm4m-v2/split_dict.pt")
valid_index = split[subset]

def flatten(d, index):
    res = []
    for x in d:
        res.extend(x[index])
    return np.array(res)

# 提取并处理每个id的所有原子坐标
def one_ckp(folder, subset):
    s = f"{folder}/" + subset + "*.pkl"
    files = sorted(glob.glob(s))
    data = []
    for file in files:
        with open(file, "rb") as f:
            try:
                data.extend(pickle.load(f))
            except Exception as e:
                print("Error in file: ", file)
                raise e

    # 提取 id 和 pos_pred
    id = flatten(data, 0)  # 分子id
    pos_pred = flatten(data, 1)  # 该分子所有原子的三维坐标

    # 将数据放入 DataFrame,每个id保留完整的分子坐标信息
    df = pd.DataFrame({"id": id, "pos_pred": list(pos_pred)})

    # 按 id 分组,计算同一个 id 的分子坐标均值(多个预测结果的均值)
    df_grouped = df.groupby("id")
    # 调试函数:检查同一个 id 下的 pos_pred 是否具有相同的形状
    def check_shapes_and_mean(x):
        shapes = [arr.shape for arr in x]
        if not all(shape == shapes[0] for shape in shapes):
            print(f"Shape mismatch for id {x.name} group: {shapes}")
            return None
        return np.mean(np.stack(x), axis=0)
    df_mean = df_grouped["pos_pred"].apply(check_shapes_and_mean)
    return df_mean

#保存结果为 Parquet 格式
def save_pos_submission_parquet(df_grouped, output_file):
    # 将每个分子的平均坐标保存为 Parquet 文件
    df_grouped = pd.DataFrame(df_grouped.tolist(), index=df_grouped.index, columns=["pos_pred"])
    df_grouped.to_parquet(output_file, index=True, compression='snappy')
    print(f"Saved position predictions to {output_file}")

df_mean_pos = one_ckp(input_folder, subset)
output_file = "pos_pred_mean.parquet"
save_pos_submission_parquet(df_mean_pos, output_file)

我想代码中是以 df_grouped = df.groupby("id")按id分组的,并不是以8个为一组,如果是切分有误,那么问题可能发生在我的模型inference过程中,但是inference.py文件我并无改动,想知道您的test-dev测试集中该id的pos_pred结果原子数量有异吗?
2.关于final pos_pred的选择问题,我下载了本文提供的Supplementary_Data中conformation_compare_fig2的sdf文件,如下:
image
分子id后的数字是表示选取的第几个conformation吗?对于Fig. 2中id为3388743是第4个而id为3428088是第2个,以此类推。所以想请问这里的选择标准是什么?是根据可视化之后的RMSD吗?然后选择RMSD最小的一个构象作为展示?

@chenkk717
Copy link
Author

chenkk717 commented Sep 19, 2024

我重新检查了unimol+模型的inference过程。在unimol_plus文件夹下的pcq.py中的load_dataset函数部分,涉及到pcq_dataset.py文件中的PCQDataset函数,其中以下代码:

        max_node_num = max([item["atom_mask"].shape[0] for item in items])
        max_node_num = (max_node_num + 1 + 3) // 4 * 4 - 1
        batched_data = {}
        for key in items[0].keys():
            samples = [item[key] for item in items]
            if key in pad_fns:
                batched_data[key] = pad_fns[key](samples, max_node_num)

max_node_num 是每个batch里最大分子原子数,然后将其调整到最接近的 4 的倍数减去 1。atom_mask对分子的真实原子位赋1,新添加的虚拟原子位赋0,包括后面涉及到的attn_mask对新添的虚拟原子位赋-inf。经过这种处理,新添的虚拟原子坐标值一开始都是0,但是在经过inference后,虚拟原子(mask标记为0)的坐标也有了预测值,导致我上面出现的同一个id(相同分子)对应的原子数不相同的情况(原子个数为每个batch里的max_node_num)。
因为有的分子不一定是生成8个conformers,batch_size是设定的为8的倍数,就会导致同一个batch里可能包含不同的分子id,然而预测过程中,mask即使为0的原子也有了坐标的预测输出,所以同一batch里面的分子原子数都是固定的max_node_num。不知道我的理解是否有偏差,希望能解答一下疑惑。

@chenkk717
Copy link
Author

您好,请问复现有结果了吗?不知道我上述猜想是否正确? @ShuqiLu

@ShuqiLu
Copy link
Collaborator

ShuqiLu commented Oct 27, 2024

不好意思没有看到您的回复,你理解的其实也差不多,这里把所有分子的原子数都置为max_node_num是为了能用pytorch并行处理一个batch内的所有分子,需要所有tensor的shape一致,所以这里用padding操作,把batch的的分子的原子数补充成相同的数目;为了使得padding的内容不实际影响模型的运算结果,所以使用atom_mask和attn_mask让padding的内容不参与实际运算;因为这篇工作我们只预测分子的能量并不取出分子坐标独立研究,所以没有对返回的分子坐标处理padding的部分,所以看起来同一个分子生成的坐标在不同batch内shape不一样。实际上如果需要取出真实原子的坐标,去掉padding的部分,可以利用atom_mask,把每个分子的atom_mask=1的位置对应的坐标取出,就是所有真实原子的坐标; 或者假设真实原子数为k,可以取出前k个坐标,利用数据中的smiles生成rdkit初始构象,再将前k个坐标填入即得到预测的3d构象(需要原始数据中的smiles不然可能没法对应)。

至于show case中选择的标准是什么,这里其实我们就选择了几个能量预测误差相对小的case展示了一下,没有过多的特殊筛选。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants