Separate P_B trunk & min_len #107

Draft: wants to merge 2 commits into base: trunk
12 changes: 12 additions & 0 deletions docs/implementation_notes.md
@@ -34,3 +34,15 @@ The code contains a specific categorical distribution type for graph actions, `G
Consider for example the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and the other to edges. An efficient way to produce logits for these actions is to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and a `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator to each tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
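
To make the batching issue concrete, here is a minimal pure-Python sketch of a per-graph log-softmax over several logit tables. It is not the library's `GraphActionCategorical` implementation; the names `per_graph_log_softmax`, `logits`, and `batch` are hypothetical, and the sketch assumes every graph owns at least one row of logits.

```python
import math
from typing import List


def per_graph_log_softmax(
    logits: List[List[List[float]]],  # one [rows][actions] table per action type
    batch: List[List[int]],  # for each table, the graph index of each row
    num_graphs: int,
) -> List[List[List[float]]]:
    """Normalise logits so that, per graph, probabilities sum to 1 across ALL tables."""
    # Pass 1: per-graph maximum, for numerical stability.
    maxes = [-math.inf] * num_graphs
    for table, rows_to_graph in zip(logits, batch):
        for row, g in zip(table, rows_to_graph):
            maxes[g] = max([maxes[g], *row])
    # Pass 2: per-graph sum of shifted exponentials.
    sums = [0.0] * num_graphs
    for table, rows_to_graph in zip(logits, batch):
        for row, g in zip(table, rows_to_graph):
            sums[g] += sum(math.exp(x - maxes[g]) for x in row)
    log_z = [m + math.log(s) for m, s in zip(maxes, sums)]
    # Pass 3: subtract each graph's log-normaliser from its own rows.
    return [
        [[x - log_z[g] for x in row] for row, g in zip(table, rows_to_graph)]
        for table, rows_to_graph in zip(logits, batch)
    ]


# Two graphs in the minibatch: graph 0 has 2 nodes and 1 edge, graph 1 has 1 node and 1 edge.
node_logits = [[0.0, 1.0], [2.0, 0.5], [1.0, 1.0]]  # (n_nodes, n_node_actions)
edge_logits = [[0.3], [0.7]]  # (n_edges, n_edge_actions)
batch = [[0, 0, 1], [0, 1]]
log_probs = per_graph_log_softmax([node_logits, edge_logits], batch, num_graphs=2)
```

The point of the sketch is that normalisation runs over all rows belonging to one graph, across every action type, rather than over any single tensor.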

### Min/max trajectory length

For historical reasons, the way min/max trajectory lengths are currently handled is somewhat contrived (contributions welcome!).

- min length: a `GraphBuildingEnvContext`'s `graph_to_Data(g, t)` receives the timestep as its second argument. It is the context's responsibility to mask the stop action in order to enforce _minimum_ trajectory lengths.
- max length: the `GraphSampler` class enforces maximum length and maximum number of nodes by terminating the trajectory if either condition is met.
- max size: both `MolBuildingEnvContext` and `FragMolBuildingEnvContext` implement a `max_nodes`/`max_frags` property that is used to mask the `AddNode` action.

Sequence environments differ somewhat: it is left to the `SeqTransformer` class to mask the stop action using the `min_len` parameter.

To output fixed-length trajectories it should be sufficient to set `cfg.algo.min_len` and `cfg.algo.max_len` to the same value. Note that in some cases, e.g. when building fragment graphs, the agent may still output trajectories that are shorter than `min_len` by combining two fragments of degree one (leaving no valid action but to stop).
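
As a rough illustration of the stop-masking rule described above, the following sketch decides whether the stop action should be unmasked at a given timestep. The function name `stop_allowed` and its arguments are hypothetical and not part of the codebase; the logic is an assumption that follows the fragment-environment behaviour, where stopping is still allowed before `min_len` if no `AddNode` action remains.

```python
def stop_allowed(t: int, min_len: int, can_stop: bool, can_add_node: bool) -> bool:
    """Return True if the stop action should be unmasked at timestep t.

    Stopping is allowed once t >= min_len, or earlier if no AddNode action
    is available (nothing is left to do but stop).
    """
    if not can_stop:  # e.g. an empty graph, or unfilled attachment points
        return False
    return t >= min_len or not can_add_node


# With min_len == max_len == 3, stop only becomes available at t == 3.
schedule = [stop_allowed(t, 3, can_stop=True, can_add_node=True) for t in range(4)]
```

Here `schedule` marks stop as unavailable for the first three timesteps, which is why setting `min_len` equal to `max_len` generally yields fixed-length trajectories.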
2 changes: 1 addition & 1 deletion src/gflownet/algo/advantage_actor_critic.py
@@ -115,7 +115,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
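
The change above threads a per-trajectory timestep into `graph_to_Data`. A minimal sketch of the indexing, using a hypothetical helper `flatten_with_timesteps` and toy string states in place of real graphs, shows that `t` restarts at 0 for every trajectory:

```python
from typing import Any, Dict, List, Tuple


def flatten_with_timesteps(trajs: List[Dict[str, Any]]) -> List[Tuple[Any, int]]:
    """Flatten trajectories into (state, t) pairs; t restarts at 0 for each trajectory."""
    return [(step[0], t) for tj in trajs for t, step in enumerate(tj["traj"])]


# Each trajectory is a list of (state, action) pairs, as in construct_batch.
trajs = [
    {"traj": [("g0", "a0"), ("g1", "a1"), ("g2", "stop")]},
    {"traj": [("h0", "b0"), ("h1", "stop")]},
]
pairs = flatten_with_timesteps(trajs)
```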
4 changes: 4 additions & 0 deletions src/gflownet/algo/config.py
@@ -97,6 +97,9 @@ class AlgoConfig:
The name of the algorithm to use (e.g. "TB")
global_batch_size : int
The batch size for training
min_len : int
If >0, prevents the agent from using the Stop action before min_len steps (trajectories may still end for
other reasons, but setting min_len == max_len should generally produce fixed-length trajectories).
max_len : int
The maximum length of a trajectory
max_nodes : int
@@ -124,6 +127,7 @@ class AlgoConfig:

method: str = "TB"
global_batch_size: int = 64
min_len: int = 0
max_len: int = 128
max_nodes: int = 128
max_edges: int = 128
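
A note on the `min_len`/`max_len` pairing described in the docstring above: the sketch below uses a hypothetical stand-in dataclass, `LengthConfigSketch`, to show the fixed-length case. The range check in `__post_init__` is an addition for illustration only; nothing in this PR performs such validation.

```python
from dataclasses import dataclass


@dataclass
class LengthConfigSketch:
    """Hypothetical stand-in for the min_len/max_len fields of AlgoConfig."""

    min_len: int = 0
    max_len: int = 128

    def __post_init__(self) -> None:
        # Illustrative sanity check (an assumption, not present in the PR).
        if not 0 <= self.min_len <= self.max_len:
            raise ValueError(f"need 0 <= min_len <= max_len, got {self.min_len} and {self.max_len}")

    @property
    def is_fixed_length(self) -> bool:
        # Stop is masked until min_len and sampling ends at max_len.
        return self.min_len == self.max_len


fixed = LengthConfigSketch(min_len=8, max_len=8)
```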
2 changes: 1 addition & 1 deletion src/gflownet/algo/envelope_q_learning.py
@@ -269,7 +269,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
19 changes: 12 additions & 7 deletions src/gflownet/algo/flow_matching.py
@@ -68,15 +68,17 @@ def construct_batch(self, trajs, cond_info, log_rewards):
"""
if not self.correct_idempotent:
# For every s' (i.e. every state except the first of each trajectory), enumerate parents
parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]]
parents = [
([relabel(*i) for i in self.env.parents(i[0])], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])
]
# convert parents to Data
parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent]
parent_graphs = [self.ctx.graph_to_Data(pstate, t) for parent, t in parents for pact, pstate in parent]
else:
# Here we again enumerate parents
states = [i[0] for tj in trajs for i in tj["traj"][1:]]
base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states]
states = [(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])]
base_parents = [([relabel(*i) for i in self.env.parents(i)], t) for i, t in states]
base_parent_graphs = [
[self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents
[self.ctx.graph_to_Data(pstate, t) for pact, pstate in parent_set] for parent_set, t in base_parents
]
parents = []
parent_graphs = []
@@ -103,9 +105,12 @@ def construct_batch(self, trajs, cond_info, log_rewards):
parent_actions = [pact for parent in parents for pact, pstate in parent]
parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)]
# convert state to Data
state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]]
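# States in traj[1:] sit at timesteps 1, 2, ..., so start counting at 1 (their parents sit at t - 1)
state_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:], start=1)]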
terminal_actions = [
self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs
self.ctx.GraphAction_to_aidx(
self.ctx.graph_to_Data(tj["traj"][-1][0], len(tj["traj"]) - 1), tj["traj"][-1][1]
)
for tj in trajs
]

# Create a batch from [*parents, *states]. This order will make it easier when computing the loss
20 changes: 17 additions & 3 deletions src/gflownet/algo/graph_sampling.py
@@ -30,7 +30,16 @@ class GraphSampler:
"""A helper class to sample from GraphActionCategorical-producing models"""

def __init__(
self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False
self,
ctx,
env,
max_len,
max_nodes,
rng,
sample_temp=1,
correct_idempotent=False,
pad_with_terminal_state=False,
# min_len=0,
):
"""
Parameters
@@ -62,6 +71,7 @@ def __init__(
self.sanitize_samples = True
self.correct_idempotent = correct_idempotent
self.pad_with_terminal_state = pad_with_terminal_state
self.consider_masks_complete = getattr(ctx, "consider_masks_complete", False)

def sample_from_model(
self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0
@@ -108,7 +118,7 @@ def not_done(lst):

for t in range(self.max_len):
# Construct graphs for the trajectories that aren't yet done
torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)]
torch_graphs = [self.ctx.graph_to_Data(i, t) for i in not_done(graphs)]
not_done_mask = torch.tensor(done, device=dev).logical_not()
# Forward pass to get GraphActionCategorical
# Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does.
@@ -153,7 +163,11 @@ def not_done(lst):
# self.env.step can raise AssertionError if the action is illegal
gp = self.env.step(graphs[i], graph_actions[j])
assert len(gp.nodes) <= self.max_nodes
except AssertionError:
except AssertionError as e:
if self.consider_masks_complete:
# If masks are considered complete, then we can safely say that we've encountered a bug
# since the agent should only be able to take legal actions (that would not raise an error)
raise e
done[i] = True
data[i]["is_valid"] = False
bck_logprob[i].append(torch.tensor([1.0], device=dev).log())
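
The error-handling change above can be summarised in isolation. This is a minimal sketch with toy stand-ins (`try_step`, `toy_step`, `LegacyContext`, and `StrictContext` are all hypothetical names): when masks are trusted, an illegal action is treated as a bug and re-raised; otherwise the trajectory is simply marked invalid.

```python
from typing import Callable, Tuple


def try_step(step_fn: Callable[[int, int], int], state: int, action: int, masks_complete: bool) -> Tuple[int, bool]:
    """Return (next_state, is_valid); re-raise on illegal actions if masks are trusted."""
    try:
        return step_fn(state, action), True
    except AssertionError:
        if masks_complete:
            raise  # a fully masked policy should never reach an illegal action
        return state, False  # otherwise keep the state and flag the trajectory as invalid


def toy_step(state: int, action: int) -> int:
    assert action >= 0, "illegal action"
    return state + action


class LegacyContext:  # a context that predates the flag
    pass


class StrictContext:
    consider_masks_complete = True


# Contexts without the attribute fall back to the permissive behaviour.
legacy_flag = getattr(LegacyContext(), "consider_masks_complete", False)
strict_flag = getattr(StrictContext(), "consider_masks_complete", False)
```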
2 changes: 1 addition & 1 deletion src/gflownet/algo/soft_q_learning.py
@@ -111,7 +111,7 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
6 changes: 4 additions & 2 deletions src/gflownet/algo/trajectory_balance.py
@@ -288,10 +288,12 @@ def construct_batch(self, trajs, cond_info, log_rewards):
A (CPU) Batch object with relevant attributes added
"""
if self.model_is_autoregressive:
torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0]) for tj in trajs]
# Since we're passing the entire sequence to an autoregressive model, it becomes its responsibility to deal
# with `t` (which is always just len(s)).
torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0], t=0) for tj in trajs]
actions = [self.ctx.GraphAction_to_aidx(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"]]
else:
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])]
actions = [
self.ctx.GraphAction_to_aidx(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
22 changes: 19 additions & 3 deletions src/gflownet/envs/frag_mol_env.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from math import log
from typing import List, Tuple
from typing import List, Optional, Tuple

import networkx as nx
import numpy as np
@@ -24,7 +24,14 @@ class FragMolBuildingEnvContext(GraphBuildingEnvContext):
fragments. Masks ensure that the agent can only perform chemically valid attachments.
"""

def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tuple[str, List[int]]] = None):
def __init__(
self,
max_frags: int = 9,
num_cond_dim: int = 0,
fragments: Optional[List[Tuple[str, List[int]]]] = None,
min_len: int = 0,
max_len: Optional[int] = None,
):
"""Construct a fragment environment
Parameters
----------
@@ -37,6 +44,8 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
the fragments of Bengio et al., 2021.
"""
self.max_frags = max_frags
self.min_len = min_len
self.max_len = max_len
if fragments is None:
smi, stems = zip(*bengio2021flow.FRAGMENTS)
else:
@@ -79,6 +88,12 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
self.num_cond_dim = num_cond_dim
self.edges_are_duplicated = True
self.edges_are_unordered = False
# This flag says that we should be able to trust the masks encoded by graph_to_Data as a ground truth when
# determining if an action is valid or not. In other words,
# - actions produced by this context should always be valid
# - masks produced by this context have the same shape as the logit tensors (e.g. we should be able to use them
# to compute a uniform policy)
self.consider_masks_complete = True
self.fail_on_missing_attr = True

# Order in which models have to output logits
@@ -179,7 +194,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
col = 1
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance
Parameters
----------
@@ -260,6 +275,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
)
add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32)
stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else np.ones((1, 1), np.float32)
stop_mask = stop_mask * ((t >= self.min_len) + (add_node_mask.sum() == 0)).clip(max=1)

return gd.Data(
**{
4 changes: 3 additions & 1 deletion src/gflownet/envs/graph_building_env.py
@@ -892,12 +892,14 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
"""
raise NotImplementedError()

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance
Parameters
----------
g: Graph
A graph instance.
t: int
The current timestep (may be ignored by some contexts)

Returns
-------
7 changes: 4 additions & 3 deletions src/gflownet/envs/mol_building_env.py
@@ -32,6 +32,7 @@ def __init__(
num_rw_feat=0,
max_nodes=None,
max_edges=None,
min_time=0,
):
"""An env context for building molecules atom-by-atom and bond-by-bond.

@@ -71,6 +72,7 @@ def __init__(
self.num_rw_feat = num_rw_feat
self.max_nodes = max_nodes
self.max_edges = max_edges
self.min_time = min_time

self.default_wildcard_replacement = "C"
self.negative_attrs = ["fill_wildcard"]
@@ -255,7 +257,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int
raise ValueError(f"Unknown action type {action.action}")
return (type_idx, int(row), int(col))

def graph_to_Data(self, g: Graph) -> gd.Data:
def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data:
"""Convert a networkx Graph to a torch geometric Data instance"""
x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32)
x[0, -1] = len(g.nodes) == 0
@@ -376,8 +378,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
edge_index=edge_index,
edge_attr=edge_attr,
non_edge_index=non_edge_index.astype(np.int64).reshape((-1, 2)).T,
stop_mask=np.ones((1, 1), dtype=np.float32)
* (len(g.nodes) > 0), # Can only stop if there's at least a node
# Can only stop if there's at least one node and t has reached min_time (keep float32 like the other masks)
stop_mask=np.ones((1, 1), dtype=np.float32) * (len(g.nodes) > 0) * (t >= self.min_time),
add_node_mask=add_node_mask,
set_node_attr_mask=set_node_attr_mask,
add_edge_mask=np.ones(
8 changes: 6 additions & 2 deletions src/gflownet/envs/seq_building_env.py
@@ -22,6 +22,9 @@ def __init__(self):
def __repr__(self):
return "".join(map(str, self.seq))

def __len__(self) -> int:
return len(self.seq)

@property
def nodes(self):
return self.seq
@@ -84,7 +87,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
"""

def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
def __init__(self, alphabet: Sequence[str], num_cond_dim=0, min_len=0):
self.alphabet = alphabet
self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

@@ -93,6 +96,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
self.pad_token = len(alphabet) + 1
self.num_actions = len(alphabet) + 1 # Alphabet + Stop
self.num_cond_dim = num_cond_dim
self.min_len = min_len

def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
# Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
@@ -115,7 +119,7 @@ def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, i
raise ValueError(action)
return (type_idx, 0, int(col))

def graph_to_Data(self, g: Graph):
def graph_to_Data(self, g: Graph, t: int):
s: Seq = g # type: ignore
return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)
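
Since the sequence context ignores `t` here, stop masking has to happen on the model side. The following is a minimal sketch of what such masking could look like, not the actual `SeqTransformer` code: `mask_stop_logit` and `softmax` are hypothetical helpers, and `stop_idx=0` is an assumption based on `Stop` coming first in `action_type_order`.

```python
import math
from typing import List


def mask_stop_logit(logits: List[float], seq_len: int, min_len: int, stop_idx: int = 0) -> List[float]:
    """Return a copy of logits with the stop logit set to -inf while seq_len < min_len."""
    masked = list(logits)
    if seq_len < min_len:
        masked[stop_idx] = -math.inf
    return masked


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) is 0.0, so masked actions get zero probability
    z = sum(exps)
    return [e / z for e in exps]


# One stop logit followed by two token logits; the sequence is still shorter than min_len.
probs = softmax(mask_stop_logit([2.0, 0.5, 0.1], seq_len=1, min_len=3))
```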

6 changes: 6 additions & 0 deletions src/gflownet/models/config.py
@@ -30,10 +30,16 @@ class ModelConfig:
The number of layers in the model
num_emb : int
The number of dimensions of the embedding
dropout : float
The dropout probability in intermediate layers
do_separate_p_b : bool
If true, constructs the backward policy using a separate model (this effectively ~doubles the number of
parameters, all other things being equal)
"""

num_layers: int = 3
num_emb: int = 128
dropout: float = 0
do_separate_p_b: bool = False
graph_transformer: GraphTransformerConfig = GraphTransformerConfig()
seq_transformer: SeqTransformerConfig = SeqTransformerConfig()