diff --git a/pyproject.toml b/pyproject.toml index d588c636..58a83288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dependencies = [ "pyro-ppl", "gpytorch", "omegaconf>=2.3", + "pandas", # needed for QM9 and HDF5 support. ] [project.optional-dependencies] diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index f2bf178a..6184bdfc 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -29,6 +29,8 @@ class TBConfig: Whether to correct for idempotent actions do_parameterize_p_b : bool Whether to parameterize the P_B distribution (otherwise it is uniform) + do_predict_n : bool + Whether to predict the number of paths in the graph do_length_normalize : bool Whether to normalize the loss by the length of the trajectory subtb_max_len : int @@ -45,6 +47,7 @@ class TBConfig: variant: TBVariant = TBVariant.TB do_correct_idempotent: bool = False do_parameterize_p_b: bool = False + do_predict_n: bool = False do_sample_p_b: bool = False do_length_normalize: bool = False subtb_max_len: int = 128 @@ -109,6 +112,8 @@ class AlgoConfig: Idem but for validation, and `self.test_data`. train_random_action_prob : float The probability of taking a random action during training + train_det_after: Optional[int] + Do not take random actions after this number of steps valid_random_action_prob : float The probability of taking a random action during validation valid_sample_cond_info : bool @@ -126,6 +131,7 @@ class AlgoConfig: offline_ratio: float = 0.5 valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 + train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 valid_sample_cond_info: bool = True sampling_tau: float = 0.0 diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 5a6ace57..7ad4fc0a 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -246,6 +246,7 @@ def sample_backward_from_graphs( def not_done(lst): return [e for i, e in enumerate(lst) if not done[i]] + # TODO: This should be doable. 
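        # A hedged sketch of what the TODO above could look like (not part of this PR; the names
        # `logits`, `mask` and `eps` are stand-ins): with probability `eps` per graph, replace the
        # model's backward logits by a uniform choice over the actions the mask allows, mirroring
        # how random actions are mixed into the forward policy.
        #
        #     import torch
        #
        #     def epsilon_uniform_logits(logits: torch.Tensor, mask: torch.Tensor, eps: float) -> torch.Tensor:
        #         is_random = torch.rand(logits.shape[0], device=logits.device) < eps
        #         uniform = torch.where(mask.bool(), torch.zeros_like(logits), torch.full_like(logits, float("-inf")))
        #         return torch.where(is_random[:, None], uniform, logits)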
if random_action_prob > 0: raise NotImplementedError("Random action not implemented for backward sampling") diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 3a98423f..75e5471f 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple import networkx as nx import numpy as np @@ -206,7 +207,7 @@ def create_training_data_from_graphs( return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob ) - trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] + trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) @@ -214,6 +215,15 @@ def create_training_data_from_graphs( ] + [1] traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) traj["result"] = traj["traj"][-1][0] + if self.cfg.do_parameterize_p_b: + traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] + # There needs to be an additonal node when we're parameterizing P_B, + # See sampling with parametrized P_B + traj["traj"].append(deepcopy(traj["traj"][-1])) + traj["is_sink"] = [0 for _ in traj["traj"]] + traj["is_sink"][-1] = 1 + traj["is_sink"][-2] = 1 + assert len(traj["bck_a"]) == len(traj["traj"]) == len(traj["is_sink"]) return trajs def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True): diff --git a/src/gflownet/config.py b/src/gflownet/config.py index be4fa879..6941e7a7 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -60,6 +60,8 @@ class Config: The number of training steps after which to validate the model checkpoint_every : Optional[int] The number of training steps after which to checkpoint the model + store_all_checkpoints : bool + Whether to store all checkpoints or only the last one print_every : int The number of training steps after which to print the training loss start_at_step : int @@ -85,6 +87,7 @@ class Config: seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None + store_all_checkpoints: bool = False print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index b26c29d2..f35bdb14 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -1,62 +1,88 @@ +import sys import tarfile import numpy as np import pandas as pd import rdkit.Chem as Chem import torch +from rdkit.Chem import QED, Descriptors from torch.utils.data import Dataset +from gflownet.utils import sascore + class QM9Dataset(Dataset): - def __init__(self, h5_file=None, xyz_file=None, train=True, target="gap", split_seed=142857, ratio=0.9): + def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9): if h5_file is not None: - self.df = pd.HDFStore(h5_file, "r")["df"] + + self.hdf = pd.HDFStore(h5_file, "r") + self.df = self.hdf["df"] + self.is_hdf = True elif xyz_file is not None: - self.load_tar() + self.df = load_tar(xyz_file) + self.is_hdf = False + else: + raise ValueError("Either h5_file or xyz_file must be provided") rng = 
np.random.default_rng(split_seed) - idcs = np.arange(len(self.df)) # TODO: error if there is no h5_file provided. Should h5 be required + idcs = np.arange(len(self.df)) rng.shuffle(idcs) - self.target = target + self.targets = targets if train: self.idcs = idcs[: int(np.floor(ratio * len(self.df)))] else: self.idcs = idcs[int(np.floor(ratio * len(self.df))) :] + self.mol_to_graph = lambda x: x - def get_stats(self, percentile=0.95): - y = self.df[self.target] - return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)] + def setup(self, task, ctx): + self.mol_to_graph = ctx.mol_to_graph - def load_tar(self, xyz_file): - f = tarfile.TarFile(xyz_file, "r") - labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] - all_mols = [] - for pt in f: - pt = f.extractfile(pt) - data = pt.read().decode().splitlines() - all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) - self.df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) + def get_stats(self, target=None, percentile=0.95): + if target is None: + target = self.targets[0] + y = self.df[target] + return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)] def __len__(self): return len(self.idcs) def __getitem__(self, idx): return ( - Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), - torch.tensor([self.df[self.target][self.idcs[idx]]]).float(), + self.mol_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])), + torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(), ) + def terminate(self): + if self.is_hdf: + self.hdf.close() -def convert_h5(): - # File obtained from - # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 - # (from http://quantum-machine.org/datasets/) - f = tarfile.TarFile("qm9.xyz.tar", "r") + +def load_tar(xyz_file): labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] + f = tarfile.TarFile(xyz_file, "r") all_mols = [] for pt in f: pt = f.extractfile(pt) # type: ignore data = pt.read().decode().splitlines() # type: ignore all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) - store = pd.HDFStore("qm9.h5", "w") - store["df"] = df + mols = df["SMILES"].map(Chem.MolFromSmiles) + df["qed"] = mols.map(QED.qed) + df["sa"] = mols.map(sascore.calculateScore) + df["mw"] = mols.map(Descriptors.MolWt) + return df + + +def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"): + """ + Convert `xyz_file` and dump it into `h5_file` + """ + # File obtained from + # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 + # (from http://quantum-machine.org/datasets/) + df = load_tar(xyz_file) + with pd.HDFStore(h5_file, "w") as store: + store["df"] = df + + +if __name__ == "__main__": + convert_h5(*sys.argv[1:]) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index f14793e4..9795467e 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -2,7 +2,7 @@ import sqlite3 from collections.abc import Iterable from copy import deepcopy -from typing import Callable, List +from typing import Callable, List, Optional import numpy as np import torch @@ -40,6 +40,7 @@ def __init__( log_dir: str = None, sample_cond_info: bool = True, random_action_prob: float = 0.0, + det_after: Optional[int] = None, 
hindsight_ratio: float = 0.0, init_train_iter: int = 0, ): @@ -99,7 +100,8 @@ def __init__( self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter self.do_validate_batch = False # Turn this on for debugging - + self.iter = 0 + self.det_after = det_after # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) # then "offline" now refers to cond info and online to x, so no duplication and we don't end # up with 2*batch_size accidentally @@ -122,7 +124,10 @@ def _idx_iterator(self): if self.stream: # If we're streaming data, just sample `offline_batch_size` indices while True: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) + if self.offline_batch_size == 0 or len(self.data) == 0: + yield np.arange(0, 0) + else: + yield self.rng.integers(0, len(self.data), self.offline_batch_size) else: # Otherwise, figure out which indices correspond to this worker worker_info = torch.utils.data.get_worker_info() @@ -156,6 +161,9 @@ def __len__(self): return len(self.data) def __iter__(self): + self.iter += 1 + if self.det_after is not None and self.iter > self.det_after: + self.random_action_prob = 0 worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 # Now that we know we are in a worker instance, we can initialize per-worker things @@ -181,6 +189,7 @@ def __iter__(self): flat_rewards = ( list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] ) + trajs = self.algo.create_training_data_from_graphs( graphs, self.model, cond_info["encoding"][:num_offline], 0 ) @@ -236,8 +245,13 @@ def __iter__(self): log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + assert len(trajs) == num_online + num_offline # Computes some metrics - extra_info = {} + extra_info = {"random_action_prob": self.random_action_prob} + if num_online > 0: + H = sum(i["fwd_logprob"] for i in trajs[num_offline:]) + extra_info["entropy"] = -H / num_online + extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]]) if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): @@ -316,6 +330,10 @@ def __iter__(self): batch.preferences = cond_info.get("preferences", None) batch.focus_dir = cond_info.get("focus_dir", None) batch.extra_info = extra_info + if self.ctx.has_n(): + log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] + batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) + batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) # TODO: we could very well just pass the cond_info dict to construct_batch above, # and the algo can decide what it wants to put in the batch object diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index ac118441..bab9506b 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,10 +1,13 @@ from collections import defaultdict +from math import log from typing import List, Tuple +import networkx as nx import numpy as np import rdkit.Chem as Chem import torch import torch_geometric.data as gd +from scipy import special from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext from gflownet.models import bengio2021flow @@ -85,6 +88,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, 
fragments: List[Tu GraphActionType.RemoveEdgeAttr, ] self.device = torch.device("cpu") + self.n_counter = NCounter() self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms()) def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): @@ -355,6 +359,86 @@ def object_to_log_repr(self, g: Graph): """Convert a Graph to a string representation""" return Chem.MolToSmiles(self.graph_to_mol(g)) + def has_n(self) -> bool: + return True + + def log_n(self, g: Graph) -> int: + return self.n_counter(g) + + +class NCounter: + """ + Dynamic program to calculate the number of trajectories to a state. + See Appendix D of "Maximum entropy GFlowNets with soft Q-learning" + by Mohammadpour et al 2024 (https://arxiv.org/abs/2312.14331) for a proof. + """ + + def __init__(self): + # Hold the log factorial + self.cache = [0.0, 0.0] + + def lfac(self, arg: int): + while arg >= len(self.cache): + self.cache.append(log(len(self.cache)) + self.cache[-1]) + return self.cache[arg] + + def lcomb(self, x, y): + # log c(x, y) = log (x! / (y! (x - y)!)) + assert x >= y + return self.lfac(x) - self.lfac(y) - self.lfac(x - y) + + @staticmethod + def root_tree(og: nx.Graph, x): + g = nx.DiGraph(nx.create_empty_copy(og)) + visited = np.zeros(len(g), bool) + visited[x] = True + q = [x] + while len(q) > 0: # print(i, x) + x = q.pop() + for i in nx.neighbors(og, x): + if not visited[i]: + visited[i] = True + g.add_edge(x, i, **(og.get_edge_data(x, i) | og.get_edge_data(i, x))) + q.append(i) + + return g + + def f(self, g, x): + elem = np.full((len(g),), -1, int) + ways = np.full((len(g),), -1, float) + + def _f(x): + if elem[x] < 0: + e, w = 0, 0 + for i in nx.neighbors(g, x): + e1, w1 = _f(i) + # edge feature + f = len(g.get_edge_data(x, i)) + for i in range(f): + w1 += np.log(e1 + i) + e1 += f + + w = w + w1 + self.lcomb(e + e1, e) + e = e + e1 + + elem[x] = e + 1 + ways[x] = w + return elem[x], ways[x] + + return _f(x)[1] + + def __call__(self, g): + if len(g) == 0: + return 0 + + acc = [] + for i in nx.nodes(g): + rg = self.root_tree(g, i) + x = self.f(rg, i) + acc.append(x) + + return special.logsumexp(acc) + def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None): if numiters is None: diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 74ba3e4f..f1b12d93 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -550,6 +550,7 @@ def __init__( slice_dict[k].to(dev) if k is not None else torch.arange(graphs.num_graphs + 1, device=dev) for k in keys ] self.logprobs = None + self.log_n = None if deduplicate_edge_index and "edge_index" in keys: for idx, k in enumerate(keys): @@ -563,6 +564,8 @@ def detach(self): new.logits = [i.detach() for i in new.logits] if new.logprobs is not None: new.logprobs = [i.detach() for i in new.logprobs] + if new.log_n is not None: + new.log_n = new.log_n.detach() return new def to(self, device): @@ -572,10 +575,28 @@ def to(self, device): self.slice = [i.to(device) for i in self.slice] if self.logprobs is not None: self.logprobs = [i.to(device) for i in self.logprobs] + if self.log_n is not None: + self.log_n = self.log_n.to(device) if self.masks is not None: self.masks = [i.to(device) for i in self.masks] return self + def log_n_actions(self): + if self.log_n is None: + self.log_n = ( + sum( + [ + scatter(m.broadcast_to(i.shape).int().sum(1), b, dim=0, dim_size=self.num_graphs, reduce="sum") + 
for m, i, b in zip(self.masks, self.logits, self.batch) + ] + ) + .clamp(1) + .float() + .log() + .clamp(1) + ) + return self.log_n + def _compute_batchwise_max( self, x: List[torch.Tensor], @@ -674,8 +695,25 @@ def sample(self) -> List[Tuple[int, int, int]]: u = [torch.rand(i.shape, device=self.dev) for i in self.logits] # Gumbel noise gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)] + + if self.masks is not None: + gumbel_safe = [ + torch.where( + mask == 1, + torch.maximum( + x, + torch.nextafter( + torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype) + ).to(x.device), + ), + torch.finfo(x.dtype).min, + ) + for x, mask in zip(gumbel, self.masks) + ] + else: + gumbel_safe = gumbel # Take the argmax - return self.argmax(x=gumbel) + return self.argmax(x=gumbel_safe) def argmax( self, @@ -922,3 +960,12 @@ def object_to_log_repr(self, g: Graph) -> str: return json.dumps( [[(i, g.nodes[i]) for i in g.nodes], [(e, g.edges[e]) for e in g.edges]], separators=(",", ":") ) + + def has_n(self) -> bool: + return False + + def log_n(self, g) -> float: + return 0.0 + + def traj_log_n(self, traj): + return [self.log_n(g) for g, _ in traj] diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 8c9c0b5d..20c05586 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -257,17 +257,17 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" - x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat)) + x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32) x[0, -1] = len(g.nodes) == 0 - add_node_mask = np.ones((x.shape[0], self.num_new_node_values)) + add_node_mask = np.ones((x.shape[0], self.num_new_node_values), dtype=np.float32) if self.max_nodes is not None and len(g.nodes) >= self.max_nodes: add_node_mask *= 0 - remove_node_mask = np.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) - remove_node_attr_mask = np.zeros((x.shape[0], len(self.settable_atom_attrs))) + remove_node_mask = np.zeros((x.shape[0], 1), dtype=np.float32) + (1 if len(g) == 0 else 0) + remove_node_attr_mask = np.zeros((x.shape[0], len(self.settable_atom_attrs)), dtype=np.float32) explicit_valence = {} max_valence = {} - set_node_attr_mask = np.ones((x.shape[0], self.num_node_attr_logits)) + set_node_attr_mask = np.ones((x.shape[0], self.num_node_attr_logits), dtype=np.float32) bridges = set(nx.bridges(g)) if not len(g.nodes): set_node_attr_mask *= 0 @@ -326,14 +326,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data: s, e = self.atom_attr_logit_slice["expl_H"] set_node_attr_mask[i, s:e] = 0 - remove_edge_mask = np.zeros((len(g.edges), 1)) + remove_edge_mask = np.zeros((len(g.edges), 1), dtype=np.float32) for i, e in enumerate(g.edges): if e not in bridges: remove_edge_mask[i] = 1 - edge_attr = np.zeros((len(g.edges) * 2, self.num_edge_dim)) - set_edge_attr_mask = np.zeros((len(g.edges), self.num_edge_attr_logits)) - remove_edge_attr_mask = np.zeros((len(g.edges), len(self.bond_attrs))) + edge_attr = np.zeros((len(g.edges) * 2, self.num_edge_dim), dtype=np.float32) + set_edge_attr_mask = np.zeros((len(g.edges), self.num_edge_attr_logits), dtype=np.float32) + remove_edge_attr_mask = np.zeros((len(g.edges), len(self.bond_attrs)), dtype=np.float32) for i, e in enumerate(g.edges): ad = g.edges[e] for 
k, sl in zip(self.bond_attrs, self.bond_attr_slice): @@ -368,17 +368,21 @@ def graph_to_Data(self, g: Graph) -> gd.Data: and explicit_valence[u] + 1 <= max_valence[u] and explicit_valence[v] + 1 <= max_valence[v] ) - ] + ], + dtype=np.float32, ) data = dict( x=x, edge_index=edge_index, edge_attr=edge_attr, non_edge_index=non_edge_index.astype(np.int64).reshape((-1, 2)).T, - stop_mask=np.ones((1, 1)) * (len(g.nodes) > 0), # Can only stop if there's at least a node + stop_mask=np.ones((1, 1), dtype=np.float32) + * (len(g.nodes) > 0), # Can only stop if there's at least a node add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, - add_edge_mask=np.ones((non_edge_index.shape[0], 1)), # Already filtered by checking for valence + add_edge_mask=np.ones( + (non_edge_index.shape[0], 1), dtype=np.float32 + ), # Already filtered by checking for valence set_edge_attr_mask=set_edge_attr_mask, remove_node_mask=remove_node_mask, remove_node_attr_mask=remove_node_attr_mask, diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index d975dfc9..ae71d74d 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -106,6 +106,30 @@ ["c1ncc2nc[nH]c2n1", [2, 6]], ] +# 18 fragments from "Towards Understanding and Improving GFlowNet Training" +# by Shen et al. (https://arxiv.org/abs/2305.07170) + +FRAGMENTS_18 = [ + ["CO", [1, 0]], + ["O=c1[nH]cnc2[nH]cnc12", [3, 6]], + ["S", [0, 0]], + ["C1CNCCN1", [2, 5]], + ["c1cc[nH+]cc1", [3, 1]], + ["c1ccccc1", [0, 2]], + ["C1CCCCC1", [0, 2]], + ["CC(C)C", [1, 2]], + ["C1CCOCC1", [0, 2]], + ["c1cn[nH]c1", [4, 0]], + ["C1CCNC1", [2, 0]], + ["c1cncnc1", [0, 1]], + ["O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4]], + ["c1ccncc1", [1, 0]], + ["O=c1nccc[nH]1", [6, 3]], + ["O=c1cc[nH]c(=O)[nH]1", [2, 4]], + ["C1CCOC1", [2, 4]], + ["C1CCNCC1", [1, 0]], +] + class MPNNet(nn.Module): def __init__( diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 98791be5..c815cce9 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -17,12 +17,21 @@ from .trainer import GFNTrainer +def model_grad_norm(model): + x = 0 + for i in model.parameters(): + if i.grad is not None: + x += (i.grad * i.grad).sum() + return torch.sqrt(x) + + class StandardOnlineTrainer(GFNTrainer): def setup_model(self): self.model = GraphTransformerGFN( self.ctx, self.cfg, do_bck=self.cfg.algo.tb.do_parameterize_p_b, + num_graph_out=self.cfg.algo.tb.do_predict_n + 1, ) def setup_algo(self): @@ -43,6 +52,22 @@ def setup_data(self): self.training_data = [] self.test_data = [] + def _opt(self, params, lr=None, momentum=None): + if lr is None: + lr = self.cfg.opt.learning_rate + if momentum is None: + momentum = self.cfg.opt.momentum + if self.cfg.opt.opt == "adam": + return torch.optim.Adam( + params, + lr, + (momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + + raise NotImplementedError(f"{self.cfg.opt.opt} is not implemented") + def setup(self): super().setup() self.offline_ratio = 0 @@ -55,14 +80,8 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.opt = self._opt(non_Z_params) + self.opt_Z = self._opt(Z_params, 
self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) @@ -77,7 +96,8 @@ def setup(self): self.mb_size = self.cfg.algo.global_batch_size self.clip_grad_callback = { "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: [torch.nn.utils.clip_grad_norm_(p, self.cfg.opt.clip_grad_param) for p in params], + "total_norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), "none": lambda x: None, }[self.cfg.opt.clip_grad_type] @@ -85,17 +105,20 @@ def setup(self): git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] self.cfg.git_hash = git_hash - os.makedirs(self.cfg.log_dir, exist_ok=True) - print("\n\nHyperparameters:\n") yaml = OmegaConf.to_yaml(self.cfg) - print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + os.makedirs(self.cfg.log_dir, exist_ok=True) + if self.print_hps: + print("\n\nHyperparameters:\n") + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f: f.write(yaml) def step(self, loss: Tensor): loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) + with torch.no_grad(): + g0 = model_grad_norm(self.model) + self.clip_grad_callback(self.model.parameters()) + g1 = model_grad_norm(self.model) self.opt.step() self.opt.zero_grad() self.opt_Z.step() @@ -105,3 +128,4 @@ def step(self, loss: Tensor): if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + return {"grad_norm": g0, "grad_norm_clip": g1} diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 28960399..7e7df30d 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -4,7 +4,7 @@ @dataclass class SEHTaskConfig: - pass # SEH just uses a temperature conditional + reduced_frag: bool = False @dataclass @@ -14,16 +14,19 @@ class SEHMOOTaskConfig: Attributes ---------- n_valid : int - The number of valid cond_info tensors to sample + The number of valid cond_info tensors to sample. n_valid_repeats : int - The number of times to repeat the valid cond_info tensors + The number of times to repeat the valid cond_info tensors. objectives : List[str] - The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. + The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "mw"]. + online_pareto_front : bool + Whether to calculate the pareto front online. """ n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) + online_pareto_front: bool = True @dataclass @@ -32,8 +35,33 @@ class QM9TaskConfig: model_path: str = "./data/qm9/qm9_model.pt" +@dataclass +class QM9MOOTaskConfig: + """ + Config for the QM9MooTask + + Attributes + ---------- + n_valid : int + The number of valid cond_info tensors to sample. + n_valid_repeats : int + The number of times to repeat the valid cond_info tensors. + objectives : List[str] + The objectives to use for the multi-objective optimization. 
Should be a subset of ["gap", "qed", "sa", "mw"]. + While "mw" can be used, it is not recommended as the molecules are already small. + online_pareto_front : bool + Whether to calculate the pareto front online. + """ + + n_valid: int = 15 + n_valid_repeats: int = 128 + objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) + online_pareto_front: bool = True + + @dataclass class TasksConfig: qm9: QM9TaskConfig = QM9TaskConfig() + qm9_moo: QM9MOOTaskConfig = QM9MOOTaskConfig() seh: SEHTaskConfig = SEHTaskConfig() seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 866a7fac..d66f571a 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,4 +1,3 @@ -import os from typing import Callable, Dict, List, Tuple, Union import numpy as np @@ -6,7 +5,6 @@ import torch.nn as nn import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol -from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset @@ -32,11 +30,12 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models(cfg.task.qm9.model_path) + self.models = self.load_task_models(cfg.task.qm9.model_path, torch.device(cfg.device)) self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() # TODO: fix interface - self._min, self._max, self._percentile_95 = self.dataset.get_stats(percentile=0.05) # type: ignore + self._min, self._max, self._percentile_95 = self.dataset.get_stats("gap", percentile=0.05) # type: ignore self._width = self._max - self._min self._rtrans = "unit+95p" # TODO: hyperparameter @@ -61,7 +60,7 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self, path): + def load_task_models(self, path, device): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? 
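        # One possible way to resolve the TODO above (a sketch, not part of this PR): the caller
        # already reads the location from the config (`self.load_task_models(cfg.task.qm9.model_path, ...)`
        # in __init__), so this method could take the config instead of a bare path, e.g.:
        #
        #     def load_task_models(self, cfg: Config, device: torch.device):
        #         path = cfg.task.qm9.model_path
        #         ...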
try: @@ -74,7 +73,7 @@ def load_task_models(self, path): "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt", ) gap_model.load_state_dict(state_dict) - gap_model.cuda() + gap_model.to(device) gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} @@ -84,16 +83,28 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) + def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: + batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) + batch.to(self.device) + preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] + preds[preds.isnan()] = 1 + preds = ( + self.flat_reward_transform(preds) + .clip(1e-4, 2) + .reshape( + -1, + ) + ) + return preds + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] is_valid = torch.tensor([i is not None for i in graphs]).bool() if not is_valid.any(): return FlatRewards(torch.zeros((0, 1))), is_valid - batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) - preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] - preds[preds.isnan()] = 1 - preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape((-1, 1)) + + preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) + assert len(preds) == is_valid.sum() return FlatRewards(preds), is_valid @@ -121,7 +132,10 @@ def set_default_hps(self, cfg: Config): def setup_env_context(self): self.ctx = MolBuildingEnvContext( - ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True + ["C", "N", "F", "O"], + expl_H_range=[0, 1, 2, 3], + num_cond_dim=self.task.num_cond_dim, + allow_5_valence_nitrogen=True, ) # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action @@ -131,8 +145,10 @@ def setup_env_context(self): # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. 
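    # setup_data below builds the refactored QM9Dataset from data/qm9.py. A usage sketch of its
    # new interface (file path and the extra target name are placeholders for illustration):
    #
    #     ds = QM9Dataset("qm9.h5", train=True, targets=["gap", "qed"])
    #     ds.setup(task, ctx)                 # __getitem__ then yields ctx graphs instead of RDKit mols
    #     lo, hi, p = ds.get_stats("gap", percentile=0.95)
    #     graph, y = ds[0]                    # y holds one value per entry in `targets`
    #     ds.terminate()                      # closes the underlying HDFStore when an h5 file was used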
def setup_data(self): - self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") - self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=["gap"]) + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=["gap"]) + self.to_terminate.append(self.training_data.terminate) + self.to_terminate.append(self.test_data.terminate) def setup_task(self): self.task = QM9GapTask( @@ -142,16 +158,7 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - -def main(): - """Example of how this model can be run.""" - yaml = YAML(typ="safe", pure=True) - config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") - with open(config_file, "r") as f: - hps = yaml.load(f) - trial = QM9GapTrainer(hps) - trial.run() - - -if __name__ == "__main__": - main() + def setup(self): + super().setup() + self.training_data.setup(self.task, self.ctx) + self.test_data.setup(self.task, self.ctx) diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml deleted file mode 100644 index 19701fac..00000000 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ /dev/null @@ -1,10 +0,0 @@ -opt: - lr_decay: 10000 -task: - qm9: - h5_path: /rxrx/data/chem/qm9/qm9.h5 - model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt -num_training_steps: 100000 -validate_every: 100 -log_dir: ./logs/debug_qm9 -num_workers: 0 diff --git a/src/gflownet/tasks/qm9/qm9_moo.py b/src/gflownet/tasks/qm9/qm9_moo.py new file mode 100644 index 00000000..cb0e8277 --- /dev/null +++ b/src/gflownet/tasks/qm9/qm9_moo.py @@ -0,0 +1,335 @@ +import pathlib +from typing import Any, Callable, Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor +from torch.utils.data import Dataset + +import gflownet.models.mxmnet as mxmnet +from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL +from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce +from gflownet.config import Config +from gflownet.data.qm9 import QM9Dataset +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer +from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks +from gflownet.trainer import FlatRewards, RewardScalar +from gflownet.utils import metrics +from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences +from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.transforms import to_logreward + + +class QM9GapMOOTask(QM9GapTask): + """Sets up a multiobjective task where the rewards are (functions of): + - the homo-lumo gap, + - its QED, + - its synthetic accessibility, + - and its molecular weight. + + The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. 
+ """ + + def __init__( + self, + dataset: Dataset, + cfg: Config, + rng: np.random.Generator = None, + wrap_model: Callable[[nn.Module], nn.Module] = None, + ): + super().__init__(dataset, cfg, rng, wrap_model) + self.cfg = cfg + mcfg = self.cfg.task.qm9_moo + self.objectives = cfg.task.qm9_moo.objectives + self.dataset = dataset + if self.cfg.cond.focus_region.focus_type is not None: + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) + else: + self.focus_cond = None + self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) + self.temperature_sample_dist = cfg.cond.temperature.sample_dist + self.temperature_dist_params = cfg.cond.temperature.dist_params + self.num_thermometer_dim = cfg.cond.temperature.num_thermometer_dim + self.num_cond_dim = ( + self.temperature_conditional.encoding_size() + + self.pref_cond.encoding_size() + + (self.focus_cond.encoding_size() if self.focus_cond is not None else 0) + ) + assert set(self.objectives) <= {"gap", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) + + def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + return FlatRewards(torch.as_tensor(y)) + + def inverse_flat_reward_transform(self, rp): + return rp + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + cond_info = super().sample_conditional_information(n, train_it) + pref_ci = self.pref_cond.sample(n) + focus_ci = ( + self.focus_cond.sample(n, train_it) if self.focus_cond is not None else {"encoding": torch.zeros(n, 0)} + ) + cond_info = { + **cond_info, + **pref_ci, + **focus_ci, + "encoding": torch.cat([cond_info["encoding"], pref_ci["encoding"], focus_ci["encoding"]], dim=1), + } + return cond_info + + def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: + """ + Encode conditional information at validation-time + We use the maximum temperature beta for inference + Args: + steer_info: Tensor of shape (Batch, 2 * n_objectives) containing the preferences and focus_dirs + in that order + Returns: + Dict[str, Tensor]: Dictionary containing the encoded conditional information + """ + n = len(steer_info) + if self.temperature_sample_dist == "constant": + beta = torch.ones(n) * self.temperature_dist_params[0] + beta_enc = torch.zeros((n, self.num_thermometer_dim)) + else: + beta = torch.ones(n) * self.temperature_dist_params[-1] + beta_enc = torch.ones((n, self.num_thermometer_dim)) + + assert len(beta.shape) == 1, f"beta should be of shape (Batch,), got: {beta.shape}" + + # TODO: positional assumption here, should have something cleaner + preferences = steer_info[:, : len(self.objectives)].float() + focus_dir = steer_info[:, len(self.objectives) :].float() + + preferences_enc = self.pref_cond.encode(preferences) + if self.focus_cond is not None: + focus_enc = self.focus_cond.encode(focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() + else: + encoding = torch.cat([beta_enc, preferences_enc], 1).float() + return { + "beta": beta, + "encoding": encoding, + "preferences": preferences, + "focus_dir": focus_dir, + } + + def relabel_condinfo_and_logrewards( + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + ): + # TODO: we seem to be relabeling tensors in place, could that cause a problem? 
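        # A hedged sketch of one way to address the TODO above (not part of this PR): operate on
        # copies so the caller's tensors are left untouched by the in-place updates to
        # cond_info["focus_dir"] and cond_info["encoding"] further down.
        #
        #     cond_info = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in cond_info.items()}
        #     log_rewards = log_rewards.clone()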
+ if self.focus_cond is None: + raise NotImplementedError("Hindsight relabeling only implemented for focus conditioning") + if self.focus_cond.cfg.focus_type is None: + return cond_info, log_rewards + # only keep hindsight_idxs that actually correspond to a violated constraint + _, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + ) + out_focus_mask = torch.logical_not(in_focus_mask) + hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] + + # relabels the focus_dirs and log_rewards + cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) + + preferences_enc = self.pref_cond.encode(cond_info["preferences"]) + focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) + cond_info["encoding"] = torch.cat( + [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 + ) + + log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) + return cond_info, log_rewards + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + """ + Compute the logreward from the flat_reward and the conditional information + """ + if isinstance(flat_reward, list): + if isinstance(flat_reward[0], Tensor): + flat_reward = torch.stack(flat_reward) + else: + flat_reward = torch.tensor(flat_reward) + + scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward) + scalarized_logrewards = to_logreward(scalarized_rewards) + focused_logreward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards) + if self.focus_cond is not None + else scalarized_logrewards + ) + tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward) + clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward) + + return RewardScalar(clamped_logreward) + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] + assert len(graphs) == len(mols) + is_valid = [i is not None for i in graphs] + is_valid_t = torch.tensor(is_valid, dtype=torch.bool) + + if not any(is_valid): + return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t + else: + flat_r: List[Tensor] = [] + for obj in self.objectives: + if obj == "gap": + flat_r.append(super().compute_reward_from_graph(graphs)) + else: + flat_r.append(aux_tasks[obj](mols, is_valid)) + + flat_rewards = torch.stack(flat_r, dim=1) + assert flat_rewards.shape[0] == is_valid_t.sum() + return FlatRewards(flat_rewards), is_valid_t + + +class QM9MOOTrainer(QM9GapTrainer): + task: QM9GapMOOTask + ctx: MolBuildingEnvContext + + def set_default_hps(self, cfg: Config): + super().set_default_hps(cfg) + cfg.algo.sampling_tau = 0.95 + # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) + # sampling and set the offline ratio to 1 + cfg.algo.valid_sample_cond_info = False + cfg.algo.valid_offline_ratio = 1 + + def setup_algo(self): + algo = self.cfg.algo.method + if algo == "MOREINFORCE": + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + elif algo == "MOQL": + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + else: + super().setup_algo() + + def setup_task(self): + self.task = QM9GapMOOTask( + dataset=self.training_data, + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, + ) + + def 
setup_model(self): + if self.cfg.algo.method == "MOQL": + self.model = GraphTransformerFragEnvelopeQL( + self.ctx, + num_emb=self.cfg.model.num_emb, + num_layers=self.cfg.model.num_layers, + num_heads=self.cfg.model.graph_transformer.num_heads, + num_objectives=len(self.cfg.task.seh_moo.objectives), + ) + else: + super().setup_model() + + def setup(self): + super().setup() + if self.cfg.task.seh_moo.online_pareto_front: + self.sampling_hooks.append( + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, + ) + ) + self.to_terminate.append(self.sampling_hooks[-1].terminate) + # instantiate preference and focus conditioning vectors for validation + + n_obj = len(self.cfg.task.seh_moo.objectives) + cond_cfg = self.cfg.cond + + # making sure hyperparameters for preferences and focus regions are consistent + if not ( + cond_cfg.focus_region.focus_type is None + or cond_cfg.focus_region.focus_type == "centered" + or (isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) == 1) + ): + assert cond_cfg.weighted_prefs.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, " + f"here focus_type={cond_cfg.focus_region.focus_type} " + f"and preference_type={cond_cfg.weighted_prefs.preference_type }" + ) + + if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1: + n_valid = len(cond_cfg.focus_region.focus_type) + else: + n_valid = self.cfg.task.seh_moo.n_valid + + # preference vectors + if cond_cfg.weighted_prefs.preference_type is None: + valid_preferences = np.ones((n_valid, n_obj)) + elif cond_cfg.weighted_prefs.preference_type == "dirichlet": + valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") + elif cond_cfg.weighted_prefs.preference_type == "seeded_single": + seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) + valid_preferences = seeded_prefs[0].reshape((1, n_obj)) + self.task.seeded_preference = valid_preferences[0] + elif cond_cfg.weighted_prefs.preference_type == "seeded_many": + valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) + else: + raise NotImplementedError(f"Unknown preference type {cond_cfg.weighted_prefs.preference_type}") + + # TODO: this was previously reported, would be nice to serialize it + # hps["fixed_focus_dirs"] = ( + # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None + # ) + if self.task.focus_cond is not None: + assert self.task.focus_cond.valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), ( + "Invalid shape for valid_preferences, " + f"{self.task.focus_cond.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + ) + + # combine preferences and focus directions (fixed focus cosim) since they could be used together + # (not either/or). 
TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, self.task.focus_cond.valid_focus_dirs], axis=1) + else: + valid_cond_vector = valid_preferences + + self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) + self.valid_sampling_hooks.append(self._top_k_hook) + + self.algo.task = self.task + + def setup_data(self): + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=self.cfg.task.qm9_moo.objectives) + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=self.cfg.task.qm9_moo.objectives) + self.to_terminate.append(self.training_data.terminate) + self.to_terminate.append(self.test_data.terminate) + + def build_callbacks(self): + # We use this class-based setup to be compatible with the DeterminedAI API, but no direct + # dependency is required. + parent = self + + class TopKMetricCB: + def on_validation_end(self, metrics: Dict[str, Any]): + top_k = parent._top_k_hook.finalize() + for i in range(len(top_k)): + metrics[f"topk_rewards_{i}"] = top_k[i] + print("validation end", metrics) + + return {"topk": TopKMetricCB()} + + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + if self.task.focus_cond is not None: + self.task.focus_cond.step_focus_model(batch, train_it) + return super().train_batch(batch, epoch_idx, batch_idx, train_it) + + def _save_state(self, it): + if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: + self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) + return super()._save_state(it) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e916f732..91d65818 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -11,6 +11,7 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset +from torch_geometric.data import Data from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph @@ -62,16 +63,21 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) + def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: + batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) + batch.to(self.device) + preds = self.models["seh"](batch).reshape((-1,)).data.cpu() + preds[preds.isnan()] = 0 + return self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] is_valid = torch.tensor([i is not None for i in graphs]).bool() if not is_valid.any(): return FlatRewards(torch.zeros((0, 1))), is_valid - batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) - preds = self.models["seh"](batch).reshape((-1,)).data.cpu() - preds[preds.isnan()] = 0 - preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1, 1)) + + preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) + assert len(preds) == is_valid.sum() return FlatRewards(preds), 
is_valid @@ -109,10 +115,11 @@ class LittleSEHDataset(Dataset): To turn on, self `cfg.algo.offline_ratio > 0`""" - def __init__(self) -> None: + def __init__(self, smis) -> None: super().__init__() self.props: List[Tensor] = [] self.mols: List[Graph] = [] + self.smis = smis def setup(self, task, ctx): rdmols = [Chem.MolFromSmiles(i) for i in SOME_MOLS] @@ -174,10 +181,18 @@ def setup_task(self): def setup_data(self): super().setup_data() - self.training_data = LittleSEHDataset() + if self.cfg.task.seh.reduced_frag: + # The examples don't work with the 18 frags + self.training_data = LittleSEHDataset([]) + else: + self.training_data = LittleSEHDataset(SOME_MOLS) def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) + self.ctx = FragMolBuildingEnvContext( + max_frags=self.cfg.algo.max_nodes, + num_cond_dim=self.task.num_cond_dim, + fragments=bengio2021flow.FRAGMENTS_18 if self.cfg.task.seh.reduced_frag else bengio2021flow.FRAGMENTS, + ) def setup(self): super().setup() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index bd597c31..31d8f769 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -25,12 +25,38 @@ from gflownet.utils.transforms import to_logreward +def safe(f, x, default): + try: + return f(x) + except Exception: + return default + + +def mol2mw(mols: list[RDMol], is_valid: list[bool], default=1000): + molwts = torch.tensor([safe(Descriptors.MolWt, i, default) for i, v in zip(mols, is_valid) if v]) + molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + return molwts + + +def mol2sas(mols: list[RDMol], is_valid: list[bool], default=10): + sas = torch.tensor([safe(sascore.calculateScore, i, default) for i, v in zip(mols, is_valid) if v]) + sas = (10 - sas) / 9 # Turn into a [0-1] reward + return sas + + +def mol2qed(mols: list[RDMol], is_valid: list[bool], default=0): + return torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v]) + + +aux_tasks = {"qed": mol2qed, "sa": mol2sas, "mw": mol2mw} + + class SEHMOOTask(SEHTask): """Sets up a multiobjective task where the rewards are (functions of): - - the the binding energy of a molecule to Soluble Epoxide Hydrolases. - - its QED - - its synthetic accessibility - - its molecular weight + - the binding energy of a molecule to Soluble Epoxide Hydrolases, + - its QED, + - its synthetic accessibility, + - and its molecular weight. The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. 
""" @@ -170,41 +196,22 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] - is_valid = torch.tensor([i is not None for i in graphs]).bool() - if not is_valid.any(): - return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid - + assert len(graphs) == len(mols) + is_valid = [i is not None for i in graphs] + is_valid_t = torch.tensor(is_valid, dtype=torch.bool) + if not any(is_valid): + return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] - if "seh" in self.objectives: - batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) - seh_preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8 - seh_preds[seh_preds.isnan()] = 0 - flat_r.append(seh_preds) - - def safe(f, x, default): - try: - return f(x) - except Exception: - return default - - if "qed" in self.objectives: - qeds = torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v.item()]) - flat_r.append(qeds) - - if "sa" in self.objectives: - sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i, v in zip(mols, is_valid) if v.item()]) - sas = (10 - sas) / 9 # Turn into a [0-1] reward - flat_r.append(sas) - - if "mw" in self.objectives: - molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i, v in zip(mols, is_valid) if v.item()]) - molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 - flat_r.append(molwts) + for obj in self.objectives: + if obj == "seh": + flat_r.append(super().compute_reward_from_graph(graphs)) + else: + flat_r.append(aux_tasks[obj](mols, is_valid)) flat_rewards = torch.stack(flat_r, dim=1) - return FlatRewards(flat_rewards), is_valid + assert flat_rewards.shape[0] == len(mols) + return FlatRewards(flat_rewards), is_valid_t class SEHMOOFragTrainer(SEHFragTrainer): @@ -236,9 +243,6 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) - def setup_model(self): if self.cfg.algo.method == "MOQL": self.model = GraphTransformerFragEnvelopeQL( @@ -253,16 +257,18 @@ def setup_model(self): def setup(self): super().setup() - self.sampling_hooks.append( - MultiObjectiveStatsHook( - 256, - self.cfg.log_dir, - compute_igd=True, - compute_pc_entropy=True, - compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, - focus_cosim=self.cfg.cond.focus_region.focus_cosim, + if self.cfg.task.seh_moo.online_pareto_front: + self.sampling_hooks.append( + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, + ) ) - ) + self.to_terminate.append(self.sampling_hooks[-1].terminate) # instantiate preference and focus conditioning vectors for validation n_obj = len(self.cfg.task.seh_moo.objectives) @@ -297,7 +303,7 @@ def setup(self): elif cond_cfg.weighted_prefs.preference_type == "seeded_many": valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) else: - raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}") + raise 
NotImplementedError(f"Unknown preference type {cond_cfg.weighted_prefs.preference_type}") # TODO: this was previously reported, would be nice to serialize it # hps["fixed_focus_dirs"] = ( @@ -348,12 +354,6 @@ def _save_state(self, it): self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) - def run(self): - super().run() - for hook in self.sampling_hooks: - if hasattr(hook, "terminate"): - hook.terminate() - class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 55e0159b..e60d742e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -1,6 +1,9 @@ +import gc import os import pathlib -from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +import shutil +import time +from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple import numpy as np import torch @@ -31,6 +34,11 @@ class GFNAlgorithm: + updates: int = 0 + + def step(self): + self.updates += 1 + def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -90,8 +98,13 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: raise NotImplementedError() +class Closable(Protocol): + def close(self): + pass + + class GFNTrainer: - def __init__(self, hps: Dict[str, Any]): + def __init__(self, hps: Dict[str, Any], print_hps=True): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters @@ -101,6 +114,8 @@ def __init__(self, hps: Dict[str, Any]): device: torch.device The torch device of the main worker. """ + self.print_hps = print_hps + self.to_terminate: List[Closable] = [] # self.setup should at least set these up: self.training_data: Dataset self.test_data: Dataset @@ -173,13 +188,14 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - placeholder = mp_object_wrapper( + wapper = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) - return placeholder, torch.device("cpu") + self.to_terminate.append(wapper.terminate) + return wapper.placeholder, torch.device("cpu") else: return obj, self.device @@ -202,6 +218,7 @@ def build_training_data_loader(self) -> DataLoader: ratio=self.cfg.algo.offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), random_action_prob=self.cfg.algo.train_random_action_prob, + det_after=self.cfg.algo.train_det_after, hindsight_ratio=self.cfg.replay.hindsight_ratio, ) for hook in self.sampling_hooks: @@ -272,11 +289,14 @@ def build_final_data_loader(self) -> DataLoader: ) def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + tick = time.time() + self.model.train() try: loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") step_info = self.step(loss) + self.algo.step() if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: @@ -288,12 +308,16 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: info.update(step_info) if hasattr(batch, "extra_info"): info.update(batch.extra_info) + info["train_time"] = 
         return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()}

     def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0) -> Dict[str, Any]:
+        tick = time.time()
+        self.model.eval()
         loss, info = self.algo.compute_batch_losses(self.model, batch)
         if hasattr(batch, "extra_info"):
             info.update(batch.extra_info)
+        info["eval_time"] = time.time() - tick
         return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()}

     def run(self, logger=None):
@@ -316,7 +340,14 @@ def run(self, logger=None):
         start = self.cfg.start_at_step + 1
         num_training_steps = self.cfg.num_training_steps
         logger.info("Starting training")
+        start_time = time.time()
         for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)):
+            # the memory fragmentation or allocation keeps growing, how often should we clean up?
+            # is changing the allocation strategy helpful?
+
+            if it % 1024 == 0:
+                gc.collect()
+                torch.cuda.empty_cache()
             epoch_idx = it // epoch_length
             batch_idx = it % epoch_length
             if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup:
@@ -325,6 +356,8 @@ def run(self, logger=None):
                 )
                 continue
             info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it)
+            info["time_spent"] = time.time() - start_time
+            start_time = time.time()
             self.log(info, it, "train")
             if it % self.print_every == 0:
                 logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items()))
@@ -344,24 +377,57 @@ def run(self, logger=None):
             self._save_state(num_training_steps)

         num_final_gen_steps = self.cfg.num_final_gen_steps
+        final_info = {}
         if num_final_gen_steps:
             logger.info(f"Generating final {num_final_gen_steps} batches ...")
             for it, batch in zip(
-                range(num_training_steps, num_training_steps + num_final_gen_steps + 1),
+                range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1),
                 cycle(final_dl),
             ):
-                pass
-            logger.info("Final generation steps completed.")
+                if hasattr(batch, "extra_info"):
+                    for k, v in batch.extra_info.items():
+                        if k not in final_info:
+                            final_info[k] = []
+                        if hasattr(v, "item"):
+                            v = v.item()
+                        final_info[k].append(v)
+                if it % self.print_every == 0:
+                    logger.info(f"Generating mols {it - num_training_steps}/{num_final_gen_steps}")
+            final_info = {k: np.mean(v) for k, v in final_info.items()}
+
+            logger.info("Final generation steps completed - " + " ".join(f"{k}:{v:.2f}" for k, v in final_info.items()))
+            self.log(final_info, num_training_steps, "final")
+
+        # for pypy and other implementations with GC, we need to manually clean up
+        del train_dl
+        del valid_dl
+        if self.cfg.num_final_gen_steps:
+            del final_dl
+
+    def terminate(self):
+        for hook in self.sampling_hooks:
+            if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate:
+                hook.terminate()
+
+        for terminate in self.to_terminate:
+            terminate()

     def _save_state(self, it):
-        torch.save(
-            {
-                "models_state_dict": [self.model.state_dict()],
-                "cfg": self.cfg,
-                "step": it,
-            },
-            open(pathlib.Path(self.cfg.log_dir) / "model_state.pt", "wb"),
-        )
+        state = {
+            "models_state_dict": [self.model.state_dict()],
+            "cfg": self.cfg,
+            "step": it,
+        }
+        if self.sampling_model is not self.model:
+            state["sampling_model_state_dict"] = [self.sampling_model.state_dict()]
+        fn = pathlib.Path(self.cfg.log_dir) / "model_state.pt"
+        with open(fn, "wb") as fd:
+            torch.save(
+                state,
+                fd,
+            )
+        if self.cfg.store_all_checkpoints:
+            shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt")

     def log(self, info, index, key):
         if not hasattr(self, "_summary_writer"):
@@ -369,6 +435,9 @@ def log(self, info, index, key):
         for k, v in info.items():
             self._summary_writer.add_scalar(f"{key}_{k}", v, index)

+    def __del__(self):
+        self.terminate()
+

 def cycle(it):
     while True:
diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py
index acb7d12e..5d087d5f 100644
--- a/src/gflownet/utils/conditioning.py
+++ b/src/gflownet/utils/conditioning.py
@@ -53,9 +53,11 @@ def sample(self, n):
         cfg = self.cfg.cond.temperature
         beta = None
         if cfg.sample_dist == "constant":
-            assert isinstance(cfg.dist_params[0], float)
-            beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32)
-            beta_enc = torch.zeros((n, cfg.num_thermometer_dim))
+            if isinstance(cfg.dist_params[0], (float, int, np.int64, np.int32)):
+                beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32)
+                beta_enc = torch.zeros((n, cfg.num_thermometer_dim))
+            else:
+                raise ValueError(f"{cfg.dist_params[0]} is not a float")
         else:
             if cfg.sample_dist == "gamma":
                 loc, scale = cfg.dist_params
@@ -101,11 +103,11 @@ def sample(self, n):
         elif self.cfg.preference_type == "seeded":
             preferences = torch.tensor(self.seeded_prefs).float().repeat(n, 1)
         elif self.cfg.preference_type == "dirichlet_exponential":
-            a = np.random.dirichlet([1] * self.num_objectives, n)
+            a = np.random.dirichlet([self.cfg.preference_param] * self.num_objectives, n)
             b = np.random.exponential(1, n)[:, None]
             preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float()
         elif self.cfg.preference_type == "dirichlet":
-            m = Dirichlet(torch.FloatTensor([1.0] * self.num_objectives))
+            m = Dirichlet(torch.FloatTensor([self.cfg.preference_param] * self.num_objectives))
             preferences = m.sample([n])
         else:
             raise ValueError(f"Unknown preference type {self.cfg.preference_type}")
diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py
index db3d3905..54d0660d 100644
--- a/src/gflownet/utils/config.py
+++ b/src/gflownet/utils/config.py
@@ -47,6 +47,7 @@ class WeightedPreferencesConfig:
     - None: All rewards equally weighted"""

     preference_type: Optional[str] = "dirichlet"
+    preference_param: Optional[float] = 1.5


 @dataclass
diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py
index cc37c127..7f5dda4c 100644
--- a/src/gflownet/utils/metrics.py
+++ b/src/gflownet/utils/metrics.py
@@ -522,6 +522,22 @@ def inv_transform(self, arr):
         return self.scale * arr + self.loc


+def all_are_tanimoto_different(thresh, fp, mode_fps, delta=16):
+    """
+    Equivalent to `all(DataStructs.BulkTanimotoSimilarity(fp, mode_fps) < thresh)` but much faster.
+    """
+    assert delta > 0
+    s = 0
+    n = len(mode_fps)
+    while s < n:
+        e = min(s + delta, n)
+        for i in DataStructs.BulkTanimotoSimilarity(fp, mode_fps[s:e]):
+            if i >= thresh:
+                return False
+        s = e
+    return True
+
+
 # Should be calculated per preference
 def compute_diverse_top_k(smiles, rewards, k, thresh=0.7):
     # mols is a list of (reward, mol)
@@ -551,7 +567,7 @@ def get_topk(rewards, k):
         Rewards obtained after taking the convex combination.
         Shape: number_of_preferences x number_of_samples
     k : int
-        Tok k value
+        Top k value

     Returns
     ----------
@@ -565,6 +581,19 @@ def get_topk(rewards, k):
     return mean_topk


+def top_k_diversity(fps, r, K):
+    x = []
+    for i in np.argsort(r)[::-1]:
+        y = fps[i]
+        if y is None:
+            continue
+        x.append(y)
+        if len(x) >= K:
+            break
+    s = np.array([DataStructs.BulkTanimotoSimilarity(i, x) for i in x])
+    return (np.sum(s) - len(x)) / (len(x) * len(x) - len(x))  # subtract the diagonal
+
+
 if __name__ == "__main__":
     # Example for 2 dimensions
     # Point set: {(1,3), (2,2), (3,1)}, l = (0,0), u = (4,4)
diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py
index 4743efa0..115bef3a 100644
--- a/src/gflownet/utils/multiobjective_hooks.py
+++ b/src/gflownet/utils/multiobjective_hooks.py
@@ -1,3 +1,4 @@
+import math
 import pathlib
 import queue
 import threading
@@ -13,6 +14,10 @@

 class MultiObjectiveStatsHook:
+    """
+    This hook is multithreaded and the keep_alive object needs to be closed for graceful termination.
+    """
+
     def __init__(
         self,
         num_to_keep: int,
@@ -55,9 +60,6 @@ def __init__(
         self.pareto_thread = threading.Thread(target=self._run_pareto_accumulation, daemon=True)
         self.pareto_thread.start()

-    def __del__(self):
-        self.stop.set()
-
     def _hsri(self, x):
         assert x.ndim == 2, "x should have shape (num points, num objectives)"
         upper = np.zeros(x.shape[-1]) + self.hsri_epsilon
@@ -71,15 +73,18 @@ def _hsri(self, x):

     def _run_pareto_accumulation(self):
         num_updates = 0
-        while not self.stop.is_set():
+        timeouts = 0
+        while not self.stop.is_set() or timeouts < 200:
             try:
                 r, smi, owid = self.pareto_queue.get(block=True, timeout=1)
             except queue.Empty:
+                timeouts += 1
                 continue
             except ConnectionError as e:
                 print("Pareto Accumulation thread Queue ConnectionError", e)
                 break
+            timeouts = 0

             # accumulates pareto fronts across batches
             if self.pareto_front is None:
                 p = self.pareto_front = r
@@ -108,14 +113,19 @@ def _run_pareto_accumulation(self):
             if num_updates % self.save_every == 0:
                 if self.pareto_queue.qsize() > 10:
                     print("Warning: pareto metrics computation lagging")
-                torch.save(
-                    {
-                        "pareto_front": self.pareto_front,
-                        "pareto_metrics": list(self.pareto_metrics),
-                        "pareto_front_smi": self.pareto_front_smi,
-                    },
-                    open(self.log_path, "wb"),
-                )
+                self._save()
+        self._save()
+
+    def _save(self):
+        with open(self.log_path, "wb") as fd:
+            torch.save(
+                {
+                    "pareto_front": self.pareto_front,
+                    "pareto_metrics": list(self.pareto_metrics),
+                    "pareto_front_smi": self.pareto_front_smi,
+                },
+                fd,
+            )

     def __call__(self, trajs, rewards, flat_rewards, cond_info):
         # locally (in-process) accumulate flat rewards to build a better pareto estimate
@@ -227,3 +237,45 @@ def finalize(self):
         top_ks = [np.mean(sorted(i)[-self.k :]) for i in repeats.values()]
         assert len(top_ks) == self.num_preferences  # Make sure we got all of them?
         return top_ks
+
+
+class RewardPercentilesHook:
+    """
+    Calculate percentiles of the reward.
+
+    Parameters
+    ----------
+    percentiles: List[float]
+        The percentiles to calculate. Should be in the range [0, 1].
+        Default: [1.0, 0.75, 0.5, 0.25, 0]
+    """
+
+    def __init__(self, percentiles=None):
+        if percentiles is None:
+            percentiles = [1.0, 0.75, 0.5, 0.25, 0]
+        self.percentiles = percentiles
+
+    def __call__(self, trajs, rewards, flat_rewards, cond_info):
+        x = np.sort(flat_rewards.numpy(), axis=0)
+        ret = {}
+        y = np.sort(rewards.numpy())
+        for p in self.percentiles:
+            f = max(min(math.floor(x.shape[0] * p), x.shape[0] - 1), 0)
+            for j in range(x.shape[1]):
+                ret[f"percentile_flat_reward_{j}_{p:.2f}"] = x[f, j]
+            ret[f"percentile_reward_{p:.2f}%"] = y[f]
+        return ret
+
+
+class TrajectoryLengthHook:
+    """
+    Report the average trajectory length.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, trajs, rewards, flat_rewards, cond_info):
+        ret = {}
+        ret["sample_len"] = sum([len(i["traj"]) for i in trajs]) / len(trajs)
+        return ret
diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py
index 9687087e..df13b565 100644
--- a/src/gflownet/utils/multiprocessing_proxy.py
+++ b/src/gflownet/utils/multiprocessing_proxy.py
@@ -106,9 +106,6 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo
         self.thread = threading.Thread(target=self.run, daemon=True)
         self.thread.start()

-    def __del__(self):
-        self.stop.set()
-
     def encode(self, m):
         if self.pickle_messages:
             return pickle.dumps(m)
@@ -123,14 +120,18 @@ def to_cpu(self, i):
         return i.detach().to(torch.device("cpu")) if isinstance(i, self.cuda_types) else i

     def run(self):
-        while not self.stop.is_set():
+        timeouts = 0
+
+        while not self.stop.is_set() or timeouts < 500:
             for qi, q in enumerate(self.in_queues):
                 try:
                     r = self.decode(q.get(True, 1e-5))
                 except queue.Empty:
+                    timeouts += 1
                     continue
                 except ConnectionError:
                     break
+                timeouts = 0
                 attr, args, kwargs = r
                 f = getattr(self.obj, attr)
                 args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args]
@@ -154,6 +155,9 @@ def run(self):
                     msg = self.to_cpu(result)
                 self.out_queues[qi].put(self.encode(msg))

+    def terminate(self):
+        self.stop.set()
+

 def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False):
     """Construct a multiprocessing object proxy for torch DataLoaders so
@@ -190,4 +194,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals
         A placeholder object whose method calls route arguments to the main process
     """
-    return MPObjectProxy(obj, num_workers, cast_types, pickle_messages).placeholder
+    return MPObjectProxy(obj, num_workers, cast_types, pickle_messages)
diff --git a/tests/test_subtb.py b/tests/test_subtb.py
index c4841689..d89ea4bf 100644
--- a/tests/test_subtb.py
+++ b/tests/test_subtb.py
@@ -1,8 +1,11 @@
 from functools import reduce

+import networkx as nx
+import numpy as np
 import torch

 from gflownet.algo.trajectory_balance import subTB
+from gflownet.envs.frag_mol_env import NCounter


 def subTB_ref(P_F, P_B, F):
@@ -27,3 +30,47 @@ def test_subTB():
         P_B = torch.randint(1, 10, (T,))
         F = torch.randint(1, 10, (T + 1,))
         assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all()
+
+
+def test_n():
+    n = NCounter()
+    x = 0
+    for i in range(1, 10):
+        x += np.log(i)
+        assert np.isclose(n.lfac(i), x)
+
+    assert np.isclose(n.lcomb(5, 2), np.log(10))
+
+
+def test_g1():
+    n = NCounter()
+    g = nx.Graph()
+    for i in range(3):
+        g.add_node(i)
+    g.add_edge(0, 1)
+    g.add_edge(1, 2)
+    rg = n.root_tree(g, 0)
+    assert n.f(rg, 0) == 0
+    rg = n.root_tree(g, 2)
+    assert n.f(rg, 2) == 0
+    rg = n.root_tree(g, 1)
+    assert np.isclose(n.f(rg, 1), np.log(2))
+
+    assert np.isclose(n(g), np.log(4))
+
+
+def test_g():
+    n = NCounter()
+    g = nx.Graph()
+    for i in range(3):
+        g.add_node(i)
+    g.add_edge(0, 1)
+    g.add_edge(1, 2, weight=2)
+    rg = n.root_tree(g, 0)
+    assert n.f(rg, 0) == 0
+    rg = n.root_tree(g, 2)
+    assert np.isclose(n.f(rg, 2), np.log(2))
+    rg = n.root_tree(g, 1)
+    assert np.isclose(n.f(rg, 1), np.log(3))
+
+    assert np.isclose(n(g), np.log(6))
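
Note (reviewer sketch, not part of the diff): the new RewardPercentilesHook and TrajectoryLengthHook are plain callables returning a dict of scalars per sampled batch, so they are registered the same way SEHMOOFragTrainer.setup() registers MultiObjectiveStatsHook above, by appending to the trainer's sampling_hooks list; their outputs then appear to flow into batch.extra_info and from there into the logged training metrics. A minimal usage sketch, assuming SEHFragTrainer lives in gflownet.tasks.seh_frag (the subclass name and percentile choice below are illustrative only):

from gflownet.tasks.seh_frag import SEHFragTrainer  # assumed import path, not shown in this diff
from gflownet.utils.multiobjective_hooks import RewardPercentilesHook, TrajectoryLengthHook


class PercentileLoggingTrainer(SEHFragTrainer):  # hypothetical subclass, for illustration only
    def setup(self):
        super().setup()
        # Each hook is called on every sampled batch; the returned scalars end up in the logged info.
        self.sampling_hooks.append(RewardPercentilesHook(percentiles=[1.0, 0.75, 0.5, 0.25, 0]))
        self.sampling_hooks.append(TrajectoryLengthHook())

Unlike MultiObjectiveStatsHook, neither hook spawns a thread, so nothing needs to go into to_terminate; hooks that do hold resources should now be shut down through GFNTrainer.terminate() (also invoked from the trainer's __del__) rather than relying on the hooks' own __del__, which this diff removes.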