-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add fixes from the maxent paper (#116)
* add fixes from the maxent paper * fix style * fix style pt.2 * style * rename print=True to print_hps * fix qm9 problems * docu * rename chunked sim * move traj len out of reward percentilehook * remove ruamel * tox * add flag to store all checkpoints * fix moohook in seh_frag and remove it in qm9 * add comment about the graceful termination of moostats * add a flag for predicting n * REMOVE USELESS QM9 THING * fix typo * upgrade qm9 * fmt * fmt * broadcast back the invalid results * add compute_reward_from_graph method to seh * use compute_reward_from_graph in seh_moo * unify terminate and to_close * fmt * ft * f * fix typo * fix runtime errors * fmt * close hdf5 gracefully * fmt * revert default num_graph_out
- Loading branch information
Showing
25 changed files
with
1,028 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,88 @@ | ||
import sys | ||
import tarfile | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import rdkit.Chem as Chem | ||
import torch | ||
from rdkit.Chem import QED, Descriptors | ||
from torch.utils.data import Dataset | ||
|
||
from gflownet.utils import sascore | ||
|
||
|
||
class QM9Dataset(Dataset):
    """Torch dataset over a QM9 table, yielding ``(molecule, targets)`` pairs.

    The table is read either from an HDF5 store (key ``"df"``) or built from
    an xyz tar archive via the module-level ``load_tar``. Row indices are
    shuffled with a seeded generator and split into a leading ``ratio``
    fraction (train) and the remainder (held-out).
    """

    def __init__(self, h5_file=None, xyz_file=None, train=True, targets=None, split_seed=142857, ratio=0.9):
        """Load the table and select the train or held-out index split.

        Args:
            h5_file: path to an HDF5 store holding the table under ``"df"``.
            xyz_file: path to an xyz tar archive; used only if ``h5_file`` is None.
            train: if True keep the first ``ratio`` fraction of the shuffled
                indices, otherwise keep the rest.
            targets: column name(s) returned as the regression target. A single
                string is accepted; defaults to ``["gap"]``.
            split_seed: seed for the shuffle that defines the split.
            ratio: fraction of rows assigned to the train split.

        Raises:
            ValueError: if neither ``h5_file`` nor ``xyz_file`` is given.
        """
        if h5_file is not None:
            # Keep the store handle around so terminate() can close it.
            self.hdf = pd.HDFStore(h5_file, "r")
            self.df = self.hdf["df"]
            self.is_hdf = True
        elif xyz_file is not None:
            self.df = load_tar(xyz_file)
            self.is_hdf = False
        else:
            raise ValueError("Either h5_file or xyz_file must be provided")

        # Avoid a shared mutable default; also tolerate a bare column name,
        # which would otherwise be iterated character by character.
        if targets is None:
            targets = ["gap"]
        elif isinstance(targets, str):
            targets = [targets]
        self.targets = list(targets)

        rng = np.random.default_rng(split_seed)
        idcs = np.arange(len(self.df))
        rng.shuffle(idcs)
        n_train = int(np.floor(ratio * len(self.df)))
        self.idcs = idcs[:n_train] if train else idcs[n_train:]

        # Identity until setup() installs the context's converter.
        self.mol_to_graph = lambda x: x

    def setup(self, task, ctx):
        """Use ``ctx.mol_to_graph`` to convert molecules returned by ``__getitem__``."""
        self.mol_to_graph = ctx.mol_to_graph

    def get_stats(self, target=None, percentile=0.95):
        """Return ``(min, max, value at percentile)`` of a target column.

        The statistics are computed over the whole table, not just this split.

        Args:
            target: column name; defaults to the first entry of ``self.targets``.
            percentile: fraction in ``[0, 1]`` locating the cut in the sorted values.
        """
        if target is None:
            target = self.targets[0]
        y = self.df[target]
        n_rows = y.shape[0]
        # Clamp so that percentile=1.0 selects the largest value instead of
        # indexing one past the end.
        cut = min(int(n_rows * percentile), n_rows - 1)
        return y.min(), y.max(), np.sort(y)[cut]

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.idcs)

    def __getitem__(self, idx):
        """Return ``(mol_to_graph(molecule), float tensor of target values)`` for split item ``idx``."""
        row = self.idcs[idx]
        mol = Chem.MolFromSmiles(self.df["SMILES"][row])
        values = [self.df[t][row] for t in self.targets]
        return (
            self.mol_to_graph(mol),
            torch.tensor(values).float(),
        )

    def terminate(self):
        """Close the HDF5 store if the table was loaded from one."""
        if self.is_hdf:
            self.hdf.close()
|
||
def load_tar(xyz_file):
    """Parse an xyz tar archive into a DataFrame of SMILES and properties.

    For each archive member, the first whitespace-separated token of the
    second-to-last line is taken as the SMILES string, and the tokens from the
    third onward on the second line are parsed as floats for the ``labels``
    columns. ``qed``, ``sa`` and ``mw`` columns are then computed from the
    parsed molecules.

    Args:
        xyz_file: path to the (uncompressed) tar archive.

    Returns:
        A DataFrame with a ``SMILES`` column, the property columns listed in
        ``labels``, and the derived ``qed``, ``sa`` and ``mw`` columns.
    """
    labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
    all_mols = []
    # The context manager closes the archive even if a member fails to parse.
    with tarfile.TarFile(xyz_file, "r") as archive:
        for member in archive:
            handle = archive.extractfile(member)
            if handle is None:
                # extractfile() yields None for non-file members (e.g. directories).
                continue
            data = handle.read().decode().splitlines()
            all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
    df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
    mols = df["SMILES"].map(Chem.MolFromSmiles)
    df["qed"] = mols.map(QED.qed)
    df["sa"] = mols.map(sascore.calculateScore)
    df["mw"] = mols.map(Descriptors.MolWt)
    return df


def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"):
    """
    Convert `xyz_file` and dump it into `h5_file`

    The table produced by `load_tar` is stored under the key ``"df"``.
    """
    # File obtained from
    # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
    # (from http://quantum-machine.org/datasets/)
    df = load_tar(xyz_file)
    # The with-block guarantees the store is flushed and closed.
    with pd.HDFStore(h5_file, "w") as store:
        store["df"] = df
|
||
|
||
if __name__ == "__main__":
    # Command-line arguments are forwarded positionally to convert_h5;
    # its defaults apply to any that are omitted.
    cli_args = sys.argv[1:]
    convert_h5(*cli_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.