merge
bengioe committed Feb 26, 2024
1 parent aa15f27 commit 6281907
Showing 18 changed files with 152 additions and 43 deletions.
12 changes: 12 additions & 0 deletions docs/implementation_notes.md
@@ -34,3 +34,15 @@ The code contains a specific categorical distribution type for graph actions, `G
Consider for example the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and an `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but the logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator to the tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
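To make the batching issue concrete, here is a hedged, loop-based sketch of why normalization has to pool all rows and action types of the same graph rather than softmax any single tensor; it is illustrative only, not the actual (vectorized) `GraphActionCategorical` implementation:

```python
import torch

# Each action type has its own (num_rows, num_actions) logit tensor, and each row belongs
# to some graph in the minibatch, so the log-normalizer must be computed per graph across
# all of its rows and all action types.
def per_graph_log_partition(logits_list, batch_list, num_graphs):
    logZ = []
    for g in range(num_graphs):
        per_graph = [l[b == g].flatten() for l, b in zip(logits_list, batch_list)]
        logZ.append(torch.cat(per_graph).logsumexp(0))
    return torch.stack(logZ)  # one softmax denominator (log-normalizer) per graph

node_logits = torch.randn(5, 3)             # e.g. AddNode logits for 5 nodes total
edge_logits = torch.randn(4, 2)             # e.g. SetEdgeAttr logits for 4 edges total
node_batch = torch.tensor([0, 0, 0, 1, 1])  # graph index of each node
edge_batch = torch.tensor([0, 0, 1, 1])     # graph index of each edge
print(per_graph_log_partition([node_logits, edge_logits], [node_batch, edge_batch], 2))  # shape (2,)
```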

### Min/max trajectory length

For historical reasons, the way min/max trajectory lengths are currently handled is somewhat contrived (contributions welcome!).

- min length: a `GraphBuildingEnvContext`'s `graph_to_Data(g, t)` receives the timestep as its second argument; the context is responsible for masking the stop action to enforce _minimum_ trajectory lengths.
- max length: the `GraphSampler` class enforces maximum length and maximum number of nodes by terminating the trajectory if either condition is met.
- max size: both `MolBuildingEnvContext` and `FragMolBuildingEnvContext` implement a `max_nodes`/`max_frags` property that is used to mask the `AddNode` action.

Sequence environments differ somewhat: masking the stop action according to the `min_len` parameter is left to the `SeqTransformer` class.

To output fixed-length trajectories, it should be sufficient to set `cfg.algo.min_len` and `cfg.algo.max_len` to the same value. Note that in some cases, e.g. when building fragment graphs, the agent may still output trajectories shorter than `min_len`, for example by combining two fragments of degree one (leaving no valid action but to stop).
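As a companion to the notes above, here is a minimal, hypothetical sketch (illustrative names, not the library's API, and not part of this commit) of how a context can use the timestep it receives in `graph_to_Data(g, t)` to mask the stop action until `min_len` is reached; the `FragMolBuildingEnvContext` changes later in this diff implement the real version:

```python
import numpy as np

class ToyContext:
    """Illustrative only: shows the min-length masking pattern described above."""

    def __init__(self, min_len: int = 0):
        self.min_len = min_len

    def stop_mask(self, g, t: int) -> np.ndarray:
        # Stop is only allowed once the graph is non-empty ...
        mask = np.ones((1, 1), dtype=np.float32) * (len(g) > 0)
        # ... and once at least `min_len` steps have been taken.
        return mask * (t >= self.min_len)

ctx = ToyContext(min_len=3)
g = {"n0", "n1"}              # stand-in for a 2-node graph
print(ctx.stop_mask(g, t=1))  # [[0.]] -> Stop masked, the trajectory must continue
print(ctx.stop_mask(g, t=3))  # [[1.]] -> Stop becomes available
```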
2 changes: 1 addition & 1 deletion src/gflownet/algo/advantage_actor_critic.py
@@ -115,7 +115,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
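The only change here is threading the timestep into `graph_to_Data`. A tiny self-contained illustration (hypothetical trajectories, not the repo's data structures) of how the nested comprehension pairs each state with its within-trajectory timestep — `enumerate` restarts at 0 for every trajectory:

```python
trajs = [
    {"traj": [("s0", "a0"), ("s1", "a1"), ("s2", "a2")]},
    {"traj": [("g0", "b0"), ("g1", "b1")]},
]
# Mirrors the comprehension above, keeping only (timestep, state) pairs.
pairs = [(t, i[0]) for tj in trajs for t, i in enumerate(tj["traj"])]
assert pairs == [(0, "s0"), (1, "s1"), (2, "s2"), (0, "g0"), (1, "g1")]
```

The same pattern appears in `envelope_q_learning.py`, `soft_q_learning.py`, and the non-autoregressive branch of `trajectory_balance.py` below.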
4 changes: 4 additions & 0 deletions src/gflownet/algo/config.py
@@ -97,6 +97,9 @@ class AlgoConfig:
The name of the algorithm to use (e.g. "TB")
global_batch_size : int
The batch size for training
min_len : int
If > 0, prevents the agent from using the Stop action before min_len steps have been taken (trajectories may still
end for other reasons, but in general setting min_len == max_len should produce fixed-length trajectories).
max_len : int
The maximum length of a trajectory
max_nodes : int
@@ -124,6 +127,7 @@

method: str = "TB"
global_batch_size: int = 64
min_len: int = 0
max_len: int = 128
max_nodes: int = 128
max_edges: int = 128
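A hedged usage sketch of the new field, using a stand-in dataclass with only the fields visible above (the real `AlgoConfig` in src/gflownet/algo/config.py has more): per the implementation notes, setting `min_len` equal to `max_len` should yield fixed-length trajectories.

```python
from dataclasses import dataclass

@dataclass
class AlgoConfigSketch:  # illustrative stand-in for AlgoConfig
    method: str = "TB"
    global_batch_size: int = 64
    min_len: int = 0
    max_len: int = 128

cfg_algo = AlgoConfigSketch(min_len=8, max_len=8)  # min_len == max_len -> fixed, 8-step trajectories
assert cfg_algo.min_len == cfg_algo.max_len
```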
2 changes: 1 addition & 1 deletion src/gflownet/algo/envelope_q_learning.py
@@ -269,7 +269,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
19 changes: 12 additions & 7 deletions src/gflownet/algo/flow_matching.py
@@ -68,15 +68,17 @@ def construct_batch(self, trajs, cond_info, log_rewards):
"""
if not self.correct_idempotent:
# For every s' (i.e. every state except the first of each trajectory), enumerate parents
parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]]
parents = [
([relabel(*i) for i in self.env.parents(i[0])], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])
]
# convert parents to Data
parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent]
parent_graphs = [self.ctx.graph_to_Data(pstate, t) for parent, t in parents for pact, pstate in parent]
else:
# Here we again enumerate parents
states = [i[0] for tj in trajs for i in tj["traj"][1:]]
base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states]
states = [(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])]
base_parents = [([relabel(*i) for i in self.env.parents(i)], t) for i, t in states]
base_parent_graphs = [
[self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents
[self.ctx.graph_to_Data(pstate, t) for pact, pstate in parent_set] for parent_set, t in base_parents
]
parents = []
parent_graphs = []
@@ -103,9 +105,12 @@ def construct_batch(self, trajs, cond_info, log_rewards):
parent_actions = [pact for parent in parents for pact, pstate in parent]
parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)]
# convert state to Data
state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]]
state_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])]
terminal_actions = [
self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs
self.ctx.GraphAction_to_aidx(
self.ctx.graph_to_Data(tj["traj"][-1][0], len(tj["traj"]) - 1), tj["traj"][-1][1]
)
for tj in trajs
]

# Create a batch from [*parents, *states]. This order will make it easier when computing the loss
20 changes: 17 additions & 3 deletions src/gflownet/algo/graph_sampling.py
@@ -30,7 +30,16 @@ class GraphSampler:
"""A helper class to sample from GraphActionCategorical-producing models"""

def __init__(
self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False
self,
ctx,
env,
max_len,
max_nodes,
rng,
sample_temp=1,
correct_idempotent=False,
pad_with_terminal_state=False,
# min_len=0,
):
"""
Parameters
@@ -62,6 +71,7 @@ def __init__(
self.sanitize_samples = True
self.correct_idempotent = correct_idempotent
self.pad_with_terminal_state = pad_with_terminal_state
self.consider_masks_complete = ctx.consider_masks_complete if hasattr(ctx, "consider_masks_complete") else False

def sample_from_model(
self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0
@@ -108,7 +118,7 @@ def not_done(lst):

for t in range(self.max_len):
# Construct graphs for the trajectories that aren't yet done
torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)]
torch_graphs = [self.ctx.graph_to_Data(i, t) for i in not_done(graphs)]
not_done_mask = torch.tensor(done, device=dev).logical_not()
# Forward pass to get GraphActionCategorical
# Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does.
@@ -153,7 +163,11 @@ def not_done(lst):
# self.env.step can raise AssertionError if the action is illegal
gp = self.env.step(graphs[i], graph_actions[j])
assert len(gp.nodes) <= self.max_nodes
except AssertionError:
except AssertionError as e:
if self.consider_masks_complete:
# If masks are considered complete, then we can safely say that we've encountered a bug
# since the agent should only be able to take legal actions (that would not raise an error)
raise e
done[i] = True
data[i]["is_valid"] = False
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
2 changes: 1 addition & 1 deletion src/gflownet/algo/soft_q_learning.py
@@ -111,7 +111,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
6 changes: 4 additions & 2 deletions src/gflownet/algo/trajectory_balance.py
@@ -288,10 +288,12 @@ def construct_batch(self, trajs, cond_info, log_rewards):
A (CPU) Batch object with relevant attributes added
"""
if self.model_is_autoregressive:
torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0]) for tj in trajs]
# Since we're passing the entire sequence to an autoregressive model, it becomes its responsibility to deal
# with `t` (which is always just len(s)).
torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0], t=0) for tj in trajs]
actions = [self.ctx.GraphAction_to_aidx(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"]]
else:
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
22 changes: 19 additions & 3 deletions src/gflownet/envs/frag_mol_env.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from math import log
from typing import List, Tuple
from typing import List, Optional, Tuple

import networkx as nx
import numpy as np
@@ -24,7 +24,14 @@ class FragMolBuildingEnvContext(GraphBuildingEnvContext):
fragments. Masks ensure that the agent can only perform chemically valid attachments.
"""

def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tuple[str, List[int]]] = None):
def __init__(
self,
max_frags: int = 9,
num_cond_dim: int = 0,
fragments: Optional[List[Tuple[str, List[int]]]] = None,
min_len: int = 0,
max_len: Optional[int] = None,
):
"""Construct a fragment environment
Parameters
----------
@@ -37,6 +44,8 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
the fragments of Bengio et al., 2021.
"""
self.max_frags = max_frags
self.min_len = min_len
self.max_len = max_len
if fragments is None:
smi, stems = zip(*bengio2021flow.FRAGMENTS)
else:
@@ -79,6 +88,12 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
self.num_cond_dim = num_cond_dim
self.edges_are_duplicated = True
self.edges_are_unordered = False
# This flag says that we should be able to trust the masks encoded by graph_to_Data as a ground truth when
# determining if an action is valid or not. In other words,
# - actions produced by this context should always be valid
# - masks produced by this context have the same shape as the logit tensors (e.g. we should be able to use them
# to compute a uniform policy)
self.consider_masks_complete = True
self.fail_on_missing_attr = True

# Order in which models have to output logits
@@ -179,7 +194,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
col = 1
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance
Parameters
----------
@@ -260,6 +275,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
)
add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32)
stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else np.ones((1, 1), np.float32)
stop_mask = stop_mask * ((t >= self.min_len) + (add_node_mask.sum() == 0)).clip(max=1)

return gd.Data(
**{
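The new `stop_mask` line above is a logical OR written in float arithmetic: stopping stays masked until `t >= min_len`, unless no fragment can be added at all (so the agent is never left without any legal action). A hedged, self-contained sketch of just that expression, with made-up mask values:

```python
import numpy as np

def stop_allowed(base_stop_mask, t, min_len, add_node_mask):
    # Same expression as in the diff: stop iff the base mask allows it AND
    # (t >= min_len OR no AddNode action is possible at all).
    return base_stop_mask * ((t >= min_len) + (add_node_mask.sum() == 0)).clip(max=1)

base = np.ones((1, 1), np.float32)        # stopping would otherwise be legal
some_frags = np.ones((2, 4), np.float32)  # some AddNode actions are still available
no_frags = np.zeros((2, 4), np.float32)   # nothing can be added anymore

print(stop_allowed(base, t=1, min_len=3, add_node_mask=some_frags))  # [[0.]] too early to stop
print(stop_allowed(base, t=5, min_len=3, add_node_mask=some_frags))  # [[1.]] min_len reached
print(stop_allowed(base, t=1, min_len=3, add_node_mask=no_frags))    # [[1.]] forced to allow stop
```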
4 changes: 3 additions & 1 deletion src/gflownet/envs/graph_building_env.py
@@ -892,12 +892,14 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
"""
raise NotImplementedError()

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance
Parameters
----------
g: Graph
A graph instance.
t: int
The current timestep (may be ignored by some contexts)
Returns
-------
7 changes: 4 additions & 3 deletions src/gflownet/envs/mol_building_env.py
@@ -32,6 +32,7 @@ def __init__(
num_rw_feat=0,
max_nodes=None,
max_edges=None,
min_time=0,
):
"""An env context for building molecules atom-by-atom and bond-by-bond.
@@ -71,6 +72,7 @@ def __init__(
self.num_rw_feat = num_rw_feat
self.max_nodes = max_nodes
self.max_edges = max_edges
self.min_time = min_time

self.default_wildcard_replacement = "C"
self.negative_attrs = ["fill_wildcard"]
@@ -255,7 +257,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
raise ValueError(f"Unknown action type {action.action}")
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32)
x[0, -1] = len(g.nodes) == 0
@@ -376,8 +378,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
edge_index=edge_index,
edge_attr=edge_attr,
non_edge_index=non_edge_index.astype(np.int64).reshape((-1, 2)).T,
stop_mask=np.ones((1, 1), dtype=np.float32)
* (len(g.nodes) > 0), # Can only stop if there's at least a node
stop_mask=np.ones((1, 1), dtype=np.float32) * (len(g.nodes) > 0) * (t >= self.min_time),  # Stop only if there is at least one node and t >= min_time
add_node_mask=add_node_mask,
set_node_attr_mask=set_node_attr_mask,
add_edge_mask=np.ones(
8 changes: 6 additions & 2 deletions src/gflownet/envs/seq_building_env.py
@@ -22,6 +22,9 @@ def __init__(self):
def __repr__(self):
return "".join(map(str, self.seq))

def __len__(self) -> int:
return len(self.seq)

@property
def nodes(self):
return self.seq
@@ -84,7 +87,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
"""

def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
def __init__(self, alphabet: Sequence[str], num_cond_dim=0, min_len=0):
self.alphabet = alphabet
self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

@@ -93,6 +96,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
self.pad_token = len(alphabet) + 1
self.num_actions = len(alphabet) + 1 # Alphabet + Stop
self.num_cond_dim = num_cond_dim
self.min_len = min_len

def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
# Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
@@ -115,7 +119,7 @@ def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, i
raise ValueError(action)
return (type_idx, 0, int(col))

def graph_to_Data(self, g: Graph):
def graph_to_Data(self, g: Graph, t: int):
s: Seq = g # type: ignore
return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)

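Since `graph_to_Data` for sequences just returns the token tensor (the timestep is implicit in the prefix length), stop-masking before `min_len` is left to the model, as the implementation notes say. A hedged sketch of that idea with a hypothetical logits tensor — this is not the actual `SeqTransformer` code, and the stop column index is an assumption:

```python
import torch

def mask_stop_before_min_len(logits: torch.Tensor, lengths: torch.Tensor,
                             min_len: int, stop_idx: int = 0) -> torch.Tensor:
    """Set the Stop logit to -inf wherever the current sequence is shorter than min_len.

    logits:  (batch, num_actions) per-state action logits, Stop assumed at column stop_idx
    lengths: (batch,) current length of each sequence in the batch
    """
    masked = logits.clone()
    masked[lengths < min_len, stop_idx] = float("-inf")
    return masked

logits = torch.zeros(3, 5)   # 3 sequences, 5 actions (Stop + 4 tokens)
lengths = torch.tensor([1, 2, 4])
out = mask_stop_before_min_len(logits, lengths, min_len=3)
print(out[:, 0])             # tensor([-inf, -inf, 0.])
```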
6 changes: 6 additions & 0 deletions src/gflownet/models/config.py
@@ -30,10 +30,16 @@ class ModelConfig:
The number of layers in the model
num_emb : int
The number of dimensions of the embedding
dropout : float
The dropout probability in intermediate layers
do_separate_p_b : bool
If true, constructs the backward policy using a separate model (this effectively ~doubles the number of
parameters, all other things being equal)
"""

num_layers: int = 3
num_emb: int = 128
dropout: float = 0
do_separate_p_b: bool = False
graph_transformer: GraphTransformerConfig = GraphTransformerConfig()
seq_transformer: SeqTransformerConfig = SeqTransformerConfig()