Skip to content

Commit

Permalink
add fixes from the maxent paper (#116)
Browse files Browse the repository at this point in the history
* add fixes from the maxent paper

* fix style

* fix style pt.2

* style

* rename print=True to print_hps

* fix qm9 problems

* docu

* rename chunked sim

* move traj len out of reward percentilehook

* remove ruamel

* tox

* add flag to store all checkpoints

* fix moohook in seh_frag and remove it in qm9

* add comment about the graceful termination of moostats

* add a flag for predicting n

* REMOVE USELESS QM9 THING

* fix typo

* upgrade qm9

* fmt

* fmt

* broadcast back the invalid results

* add compute_reward_from_graph method to seh

* use compute_reward_from_graph in seh_moo

* unify trminate and to_close

* fmt

* ft

* f

* fix typo

* fix runtime errors

* fmt

* close hdf5 gracefully

* fmt

* revert default num_graph_out
  • Loading branch information
SobhanMP authored Feb 26, 2024
1 parent ef5f2cb commit aa15f27
Show file tree
Hide file tree
Showing 25 changed files with 1,028 additions and 201 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dependencies = [
"pyro-ppl",
"gpytorch",
"omegaconf>=2.3",
"pandas", # needed for QM9 and HDF5 support.
]

[project.optional-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class TBConfig:
Whether to correct for idempotent actions
do_parameterize_p_b : bool
Whether to parameterize the P_B distribution (otherwise it is uniform)
do_predict_n : bool
Whether to predict the number of paths in the graph
do_length_normalize : bool
Whether to normalize the loss by the length of the trajectory
subtb_max_len : int
Expand All @@ -45,6 +47,7 @@ class TBConfig:
variant: TBVariant = TBVariant.TB
do_correct_idempotent: bool = False
do_parameterize_p_b: bool = False
do_predict_n: bool = False
do_sample_p_b: bool = False
do_length_normalize: bool = False
subtb_max_len: int = 128
Expand Down Expand Up @@ -109,6 +112,8 @@ class AlgoConfig:
Idem but for validation, and `self.test_data`.
train_random_action_prob : float
The probability of taking a random action during training
train_det_after: Optional[int]
Do not take random actions after this number of steps
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Expand All @@ -126,6 +131,7 @@ class AlgoConfig:
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
train_random_action_prob: float = 0.0
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/algo/graph_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def sample_backward_from_graphs(
def not_done(lst):
return [e for i, e in enumerate(lst) if not done[i]]

# TODO: This should be doable.
if random_action_prob > 0:
raise NotImplementedError("Random action not implemented for backward sampling")

Expand Down
14 changes: 12 additions & 2 deletions src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Tuple
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -206,14 +207,23 @@ def create_training_data_from_graphs(
return self.graph_sampler.sample_backward_from_graphs(
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob
)
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs]
for traj in trajs:
n_back = [
self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent)
for gp, _ in traj["traj"][1:]
] + [1]
traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device)
traj["result"] = traj["traj"][-1][0]
if self.cfg.do_parameterize_p_b:
traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]]
# There needs to be an additonal node when we're parameterizing P_B,
# See sampling with parametrized P_B
traj["traj"].append(deepcopy(traj["traj"][-1]))
traj["is_sink"] = [0 for _ in traj["traj"]]
traj["is_sink"][-1] = 1
traj["is_sink"][-2] = 1
assert len(traj["bck_a"]) == len(traj["traj"]) == len(traj["is_sink"])
return trajs

def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True):
Expand Down
3 changes: 3 additions & 0 deletions src/gflownet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Config:
The number of training steps after which to validate the model
checkpoint_every : Optional[int]
The number of training steps after which to checkpoint the model
store_all_checkpoints : bool
Whether to store all checkpoints or only the last one
print_every : int
The number of training steps after which to print the training loss
start_at_step : int
Expand All @@ -85,6 +87,7 @@ class Config:
seed: int = 0
validate_every: int = 1000
checkpoint_every: Optional[int] = None
store_all_checkpoints: bool = False
print_every: int = 100
start_at_step: int = 0
num_final_gen_steps: Optional[int] = None
Expand Down
78 changes: 52 additions & 26 deletions src/gflownet/data/qm9.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,88 @@
import sys
import tarfile

import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
from rdkit.Chem import QED, Descriptors
from torch.utils.data import Dataset

from gflownet.utils import sascore


class QM9Dataset(Dataset):
def __init__(self, h5_file=None, xyz_file=None, train=True, target="gap", split_seed=142857, ratio=0.9):
def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9):
if h5_file is not None:
self.df = pd.HDFStore(h5_file, "r")["df"]

self.hdf = pd.HDFStore(h5_file, "r")
self.df = self.hdf["df"]
self.is_hdf = True
elif xyz_file is not None:
self.load_tar()
self.df = load_tar(xyz_file)
self.is_hdf = False
else:
raise ValueError("Either h5_file or xyz_file must be provided")
rng = np.random.default_rng(split_seed)
idcs = np.arange(len(self.df)) # TODO: error if there is no h5_file provided. Should h5 be required
idcs = np.arange(len(self.df))
rng.shuffle(idcs)
self.target = target
self.targets = targets
if train:
self.idcs = idcs[: int(np.floor(ratio * len(self.df)))]
else:
self.idcs = idcs[int(np.floor(ratio * len(self.df))) :]
self.mol_to_graph = lambda x: x

def get_stats(self, percentile=0.95):
y = self.df[self.target]
return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)]
def setup(self, task, ctx):
self.mol_to_graph = ctx.mol_to_graph

def load_tar(self, xyz_file):
f = tarfile.TarFile(xyz_file, "r")
labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
all_mols = []
for pt in f:
pt = f.extractfile(pt)
data = pt.read().decode().splitlines()
all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
self.df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
def get_stats(self, target=None, percentile=0.95):
if target is None:
target = self.targets[0]
y = self.df[target]
return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)]

def __len__(self):
return len(self.idcs)

def __getitem__(self, idx):
return (
Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]),
torch.tensor([self.df[self.target][self.idcs[idx]]]).float(),
self.mol_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])),
torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(),
)

def terminate(self):
if self.is_hdf:
self.hdf.close()

def convert_h5():
# File obtained from
# https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
# (from http://quantum-machine.org/datasets/)
f = tarfile.TarFile("qm9.xyz.tar", "r")

def load_tar(xyz_file):
labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
f = tarfile.TarFile(xyz_file, "r")
all_mols = []
for pt in f:
pt = f.extractfile(pt) # type: ignore
data = pt.read().decode().splitlines() # type: ignore
all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
store = pd.HDFStore("qm9.h5", "w")
store["df"] = df
mols = df["SMILES"].map(Chem.MolFromSmiles)
df["qed"] = mols.map(QED.qed)
df["sa"] = mols.map(sascore.calculateScore)
df["mw"] = mols.map(Descriptors.MolWt)
return df


def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"):
"""
Convert `xyz_file` and dump it into `h5_file`
"""
# File obtained from
# https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
# (from http://quantum-machine.org/datasets/)
df = load_tar(xyz_file)
with pd.HDFStore(h5_file, "w") as store:
store["df"] = df


if __name__ == "__main__":
convert_h5(*sys.argv[1:])
26 changes: 22 additions & 4 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sqlite3
from collections.abc import Iterable
from copy import deepcopy
from typing import Callable, List
from typing import Callable, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -40,6 +40,7 @@ def __init__(
log_dir: str = None,
sample_cond_info: bool = True,
random_action_prob: float = 0.0,
det_after: Optional[int] = None,
hindsight_ratio: float = 0.0,
init_train_iter: int = 0,
):
Expand Down Expand Up @@ -99,7 +100,8 @@ def __init__(
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging

self.iter = 0
self.det_after = det_after
# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
# up with 2*batch_size accidentally
Expand All @@ -122,7 +124,10 @@ def _idx_iterator(self):
if self.stream:
# If we're streaming data, just sample `offline_batch_size` indices
while True:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
if self.offline_batch_size == 0 or len(self.data) == 0:
yield np.arange(0, 0)
else:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
else:
# Otherwise, figure out which indices correspond to this worker
worker_info = torch.utils.data.get_worker_info()
Expand Down Expand Up @@ -156,6 +161,9 @@ def __len__(self):
return len(self.data)

def __iter__(self):
self.iter += 1
if self.det_after is not None and self.iter > self.det_after:
self.random_action_prob = 0
worker_info = torch.utils.data.get_worker_info()
self._wid = worker_info.id if worker_info is not None else 0
# Now that we know we are in a worker instance, we can initialize per-worker things
Expand All @@ -181,6 +189,7 @@ def __iter__(self):
flat_rewards = (
list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else []
)

trajs = self.algo.create_training_data_from_graphs(
graphs, self.model, cond_info["encoding"][:num_offline], 0
)
Expand Down Expand Up @@ -236,8 +245,13 @@ def __iter__(self):
log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards)
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

assert len(trajs) == num_online + num_offline
# Computes some metrics
extra_info = {}
extra_info = {"random_action_prob": self.random_action_prob}
if num_online > 0:
H = sum(i["fwd_logprob"] for i in trajs[num_offline:])
extra_info["entropy"] = -H / num_online
extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]])
if not self.sample_cond_info:
# If we're using a dataset of preferences, the user may want to know the id of the preference
for i, j in zip(trajs, idcs):
Expand Down Expand Up @@ -316,6 +330,10 @@ def __iter__(self):
batch.preferences = cond_info.get("preferences", None)
batch.focus_dir = cond_info.get("focus_dir", None)
batch.extra_info = extra_info
if self.ctx.has_n():
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

Expand Down
Loading

0 comments on commit aa15f27

Please sign in to comment.