diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 600bb1d4..df463d4d 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -34,3 +34,15 @@ The code contains a specific categorical distribution type for graph actions, `G Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. + +### Min/max trajectory length + +The current way min/max trajectory lengths are handled is somewhat contrived (contributions welcome!) for historical reasons. + +- min length: a `GraphBuildingEnvContext`'s `graph_to_Data(g, t)` receives the timestep as its second argument. The responsibility of masking the stop action is left to the context to enforce _minimum_ trajectory lengths. +- max length: the `GraphSampler` class enforces maximum length and maximum number of nodes by terminating the trajectory if either condition is met. +- max size: both `MolBuildingEnvContext` and `FragMolBuildingEnvContext` implement a `max_nodes`/`max_frags` property that is used to mask the `AddNode` action. + +Sequence environments differ somewhat, it's left to the `SeqTransformer` class to mask the stop action using the `min_len` parameter. + +To output fixed-length trajectories it should be sufficient to set `cfg.algo.min_len` and `cfg.algo.max_len` to the same value. Note that in some cases, e.g. when building fragment graphs, the agent may still output trajectories that are shorter than `min_len` by combining two fragments of degree one (leaving no valid action but to stop). diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 001e19d0..8ed338a5 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -115,7 +115,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch: gd.Batch A (CPU) Batch object with relevant attributes added """ - torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])] actions = [ self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 6184bdfc..7e45af49 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -97,6 +97,9 @@ class AlgoConfig: The name of the algorithm to use (e.g. "TB") global_batch_size : int The batch size for training + min_len: int + If >0, prevents the agent from using the Stop action before min_len steps (trajectories may still end for + other reasons, but generally setting min_len==max_len should produce fixed length trajectories). max_len : int The maximum length of a trajectory max_nodes : int @@ -124,6 +127,7 @@ class AlgoConfig: method: str = "TB" global_batch_size: int = 64 + min_len: int = 0 max_len: int = 128 max_nodes: int = 128 max_edges: int = 128 diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 4d694ae2..f6295fb7 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -269,7 +269,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch: gd.Batch A (CPU) Batch object with relevant attributes added """ - torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])] actions = [ self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index a1e9a393..763ace20 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -68,15 +68,17 @@ def construct_batch(self, trajs, cond_info, log_rewards): """ if not self.correct_idempotent: # For every s' (i.e. every state except the first of each trajectory), enumerate parents - parents = [[relabel(*i) for i in self.env.parents(i[0])] for tj in trajs for i in tj["traj"][1:]] + parents = [ + ([relabel(*i) for i in self.env.parents(i[0])], t) for tj in trajs for t, i in enumerate(tj["traj"][1:]) + ] # convert parents to Data - parent_graphs = [self.ctx.graph_to_Data(pstate) for parent in parents for pact, pstate in parent] + parent_graphs = [self.ctx.graph_to_Data(pstate, t) for parent, t in parents for pact, pstate in parent] else: # Here we again enumerate parents - states = [i[0] for tj in trajs for i in tj["traj"][1:]] - base_parents = [[relabel(*i) for i in self.env.parents(i)] for i in states] + states = [(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])] + base_parents = [([relabel(*i) for i in self.env.parents(i)], t) for i, t in states] base_parent_graphs = [ - [self.ctx.graph_to_Data(pstate) for pact, pstate in parent_set] for parent_set in base_parents + [self.ctx.graph_to_Data(pstate, t) for pact, pstate in parent_set] for parent_set, t in base_parents ] parents = [] parent_graphs = [] @@ -103,9 +105,12 @@ def construct_batch(self, trajs, cond_info, log_rewards): parent_actions = [pact for parent in parents for pact, pstate in parent] parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)] # convert state to Data - state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]] + state_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"][1:])] terminal_actions = [ - self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs + self.ctx.GraphAction_to_aidx( + self.ctx.graph_to_Data(tj["traj"][-1][0], len(tj["traj"]) - 1), tj["traj"][-1][1] + ) + for tj in trajs ] # Create a batch from [*parents, *states]. This order will make it easier when computing the loss diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 7ad4fc0a..af8201f4 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -30,7 +30,16 @@ class GraphSampler: """A helper class to sample from GraphActionCategorical-producing models""" def __init__( - self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False + self, + ctx, + env, + max_len, + max_nodes, + rng, + sample_temp=1, + correct_idempotent=False, + pad_with_terminal_state=False, + # min_len=0, ): """ Parameters @@ -62,6 +71,7 @@ def __init__( self.sanitize_samples = True self.correct_idempotent = correct_idempotent self.pad_with_terminal_state = pad_with_terminal_state + self.consider_masks_complete = ctx.consider_masks_complete if hasattr(ctx, "consider_masks_complete") else False def sample_from_model( self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0 @@ -108,7 +118,7 @@ def not_done(lst): for t in range(self.max_len): # Construct graphs for the trajectories that aren't yet done - torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] + torch_graphs = [self.ctx.graph_to_Data(i, t) for i in not_done(graphs)] not_done_mask = torch.tensor(done, device=dev).logical_not() # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. @@ -153,7 +163,11 @@ def not_done(lst): # self.env.step can raise AssertionError if the action is illegal gp = self.env.step(graphs[i], graph_actions[j]) assert len(gp.nodes) <= self.max_nodes - except AssertionError: + except AssertionError as e: + if self.consider_masks_complete: + # If masks are considered complete, then we can safely say that we've encountered a bug + # since the agent should only be able to take legal actions (that would not raise an error) + raise e done[i] = True data[i]["is_valid"] = False bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 1e3f1146..4d726c53 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -111,7 +111,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch: gd.Batch A (CPU) Batch object with relevant attributes added """ - torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])] actions = [ self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) ] diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 75e5471f..ba6828c6 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -288,10 +288,12 @@ def construct_batch(self, trajs, cond_info, log_rewards): A (CPU) Batch object with relevant attributes added """ if self.model_is_autoregressive: - torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0]) for tj in trajs] + # Since we're passing the entire sequence to an autoregressive model, it becomes its responsibility to deal + # with `t` (which is always just len(s)). + torch_graphs = [self.ctx.graph_to_Data(tj["traj"][-1][0], t=0) for tj in trajs] actions = [self.ctx.GraphAction_to_aidx(g, i[1]) for g, tj in zip(torch_graphs, trajs) for i in tj["traj"]] else: - torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]] + torch_graphs = [self.ctx.graph_to_Data(i[0], t) for tj in trajs for t, i in enumerate(tj["traj"])] actions = [ self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]]) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index bab9506b..03b2570e 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,6 +1,6 @@ from collections import defaultdict from math import log -from typing import List, Tuple +from typing import List, Optional, Tuple import networkx as nx import numpy as np @@ -24,7 +24,14 @@ class FragMolBuildingEnvContext(GraphBuildingEnvContext): fragments. Masks ensure that the agent can only perform chemically valid attachments. """ - def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tuple[str, List[int]]] = None): + def __init__( + self, + max_frags: int = 9, + num_cond_dim: int = 0, + fragments: Optional[List[Tuple[str, List[int]]]] = None, + min_len: int = 0, + max_len: Optional[int] = None, + ): """Construct a fragment environment Parameters ---------- @@ -37,6 +44,8 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu the fragments of Bengio et al., 2021. """ self.max_frags = max_frags + self.min_len = min_len + self.max_len = max_len if fragments is None: smi, stems = zip(*bengio2021flow.FRAGMENTS) else: @@ -79,6 +88,12 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.num_cond_dim = num_cond_dim self.edges_are_duplicated = True self.edges_are_unordered = False + # This flags says that we should be able to trust the masks encoded by graph_to_Data as a ground truth when + # determining if an action is valid or not. In other words, + # - actions produced by this context should always be valid + # - masks produced by this context have the same shape as the logit tensors (e.g. we should be able to use them + # to compute a uniform policy) + self.consider_masks_complete = True self.fail_on_missing_attr = True # Order in which models have to output logits @@ -179,7 +194,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int col = 1 return (type_idx, int(row), int(col)) - def graph_to_Data(self, g: Graph) -> gd.Data: + def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance Parameters ---------- @@ -260,6 +275,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: ) add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32) stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else np.ones((1, 1), np.float32) + stop_mask = stop_mask * ((t >= self.min_len) + (add_node_mask.sum() == 0)).clip(max=1) return gd.Data( **{ diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index f1b12d93..a56127c7 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -892,12 +892,14 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int """ raise NotImplementedError() - def graph_to_Data(self, g: Graph) -> gd.Data: + def graph_to_Data(self, g: Graph, t: int) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance Parameters ---------- g: Graph A graph instance. + t: + The current timestep (may be ignored by some contexts) Returns ------- diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 20c05586..9c6495da 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -32,6 +32,7 @@ def __init__( num_rw_feat=0, max_nodes=None, max_edges=None, + min_time=0, ): """An env context for building molecules atom-by-atom and bond-by-bond. @@ -71,6 +72,7 @@ def __init__( self.num_rw_feat = num_rw_feat self.max_nodes = max_nodes self.max_edges = max_edges + self.min_time = 0 self.default_wildcard_replacement = "C" self.negative_attrs = ["fill_wildcard"] @@ -255,7 +257,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int raise ValueError(f"Unknown action type {action.action}") return (type_idx, int(row), int(col)) - def graph_to_Data(self, g: Graph) -> gd.Data: + def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32) x[0, -1] = len(g.nodes) == 0 @@ -376,8 +378,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: edge_index=edge_index, edge_attr=edge_attr, non_edge_index=non_edge_index.astype(np.int64).reshape((-1, 2)).T, - stop_mask=np.ones((1, 1), dtype=np.float32) - * (len(g.nodes) > 0), # Can only stop if there's at least a node + stop_mask=np.ones((1, 1)) * (len(g.nodes) > 0) * (t >= self.min_time), # Only stop if there's 1+ nodes add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, add_edge_mask=np.ones( diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index b8189690..f0b5d57b 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -22,6 +22,9 @@ def __init__(self): def __repr__(self): return "".join(map(str, self.seq)) + def __len__(self) -> int: + return len(self.seq) + @property def nodes(self): return self.seq @@ -84,7 +87,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext): This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion. """ - def __init__(self, alphabet: Sequence[str], num_cond_dim=0): + def __init__(self, alphabet: Sequence[str], num_cond_dim=0, min_len=0): self.alphabet = alphabet self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode] @@ -93,6 +96,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0): self.pad_token = len(alphabet) + 1 self.num_actions = len(alphabet) + 1 # Alphabet + Stop self.num_cond_dim = num_cond_dim + self.min_len = min_len def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction: # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0 @@ -115,7 +119,7 @@ def GraphAction_to_aidx(self, g: Data, action: GraphAction) -> Tuple[int, int, i raise ValueError(action) return (type_idx, 0, int(col)) - def graph_to_Data(self, g: Graph): + def graph_to_Data(self, g: Graph, t: int): s: Seq = g # type: ignore return torch.tensor([self.bos_token] + s.seq, dtype=torch.long) diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index e4955d29..8bb62114 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -30,10 +30,16 @@ class ModelConfig: The number of layers in the model num_emb : int The number of dimensions of the embedding + dropout : float + The dropout probability in intermediate layers + separate_pB : bool + If true, constructs the backward policy using a separate model (this effectively ~doubles the number of + parameters, all other things being equal) """ num_layers: int = 3 num_emb: int = 128 dropout: float = 0 + do_separate_p_b: bool = False graph_transformer: GraphTransformerConfig = GraphTransformerConfig() seq_transformer: SeqTransformerConfig = SeqTransformerConfig() diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..e8672f9a 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -172,17 +172,24 @@ def __init__( num_graph_out=1, do_bck=False, ): - """See `GraphTransformer` for argument values""" + """See `GraphTransformer` and its config for argument values + + Parameters + ---------- + env_ctx: GraphBuildingEnvContext + The environment context. This is used to determine the number of actions, and input and output shapes. + cfg: Config + A Config object containing a model configuration. + num_graph_out: int + The number of outputs of the final MLP applied to the graph embeddings. + do_bck: bool + If true, also outputs a backward action distribution. + """ super().__init__() - self.transf = GraphTransformer( - x_dim=env_ctx.num_node_dim, - e_dim=env_ctx.num_edge_dim, - g_dim=env_ctx.num_cond_dim, - num_emb=cfg.model.num_emb, - num_layers=cfg.model.num_layers, - num_heads=cfg.model.graph_transformer.num_heads, - ln_type=cfg.model.graph_transformer.ln_type, - ) + self.trunk = self._make_trunk(env_ctx, cfg) + self.do_separate_p_b = cfg.model.do_separate_p_b + if cfg.model.do_separate_p_b: + self.bck_trunk = self._make_trunk(env_ctx, cfg) num_emb = cfg.model.num_emb num_final = num_emb num_glob_final = num_emb * 2 @@ -223,6 +230,17 @@ def __init__( # TODO: flag for this self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) + def _make_trunk(self, env_ctx, cfg): + return GraphTransformer( + x_dim=env_ctx.num_node_dim, + e_dim=env_ctx.num_edge_dim, + g_dim=env_ctx.num_cond_dim, + num_emb=cfg.model.num_emb, + num_layers=cfg.model.num_layers, + num_heads=cfg.model.graph_transformer.num_heads, + ln_type=cfg.model.graph_transformer.ln_type, + ) + def _action_type_to_mask(self, t, g): return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device) @@ -244,8 +262,8 @@ def _make_cat(self, g, emb, action_types): types=action_types, ) - def forward(self, g: gd.Batch, cond: torch.Tensor): - node_embeddings, graph_embeddings = self.transf(g, cond) + def _compute_embs(self, model, g, cond): + node_embeddings, graph_embeddings = model(g, cond) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): ne_row, ne_col = g.non_edge_index @@ -267,16 +285,23 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): else: edge_embeddings = torch.cat([node_embeddings[e_row], node_embeddings[e_col]], 1) - emb = { + return { "graph": graph_embeddings, "node": node_embeddings, "edge": edge_embeddings, "non_edge": non_edge_embeddings, } - graph_out = self.emb2graph_out(graph_embeddings) - fwd_cat = self._make_cat(g, emb, self.action_type_order) + def forward(self, g: gd.Batch, cond: torch.Tensor): + embs = self._compute_embs(self.trunk, g, cond) + if self.do_separate_p_b: + bck_embs = self._compute_embs(self.bck_trunk, g, cond) + else: + bck_embs = embs + + graph_out = self.emb2graph_out(embs["graph"]) + fwd_cat = self._make_cat(g, embs, self.action_type_order) if self.do_bck: - bck_cat = self._make_cat(g, emb, self.bck_action_type_order) + bck_cat = self._make_cat(g, bck_embs, self.bck_action_type_order) return fwd_cat, bck_cat, graph_out return fwd_cat, graph_out diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 6916366a..da9e52d9 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -36,9 +36,11 @@ def __init__( env_ctx, cfg: Config, num_state_out=1, + min_len=0, ): super().__init__() self.ctx = env_ctx + self.min_len = min_len self.num_state_out = num_state_out num_hid = cfg.model.num_emb num_outs = env_ctx.num_actions + num_state_out @@ -58,6 +60,8 @@ def __init__( else: self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout) self.num_hid = num_hid + # TODO: Merge non-autoregressive implementations of sequence generation + assert not cfg.model.do_separate_pb, "Not implemented for SeqTransformerGFN (since P_B=1 when autoregressive)." def forward(self, xs: SeqBatch, cond, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. @@ -96,6 +100,12 @@ def forward(self, xs: SeqBatch, cond, batched=False): add_node_logits = out[xs.logit_idx, ns + 1 :] # (proper_time, nout - 1) # `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs)) # which is what we need to give to GraphActionCategorical + stop_mask = torch.ones_like(stop_logits) + if self.min_len > 0: + # The +1 accounts for the BOS token + stop_mask = torch.cat([torch.arange(1, i + 1) >= self.min_len for i in xs.lens]) + stop_mask = stop_mask.to(stop_logits.device).float().unsqueeze(-1) + stop_logits = stop_logits * stop_mask - 1000 * (1 - stop_mask) else: # The default num_graphs is computed for the batched case, so we need to fix it here so that # GraphActionCategorical knows how many "graphs" (sequence inputs) there are @@ -104,6 +114,11 @@ def forward(self, xs: SeqBatch, cond, batched=False): state_preds = out[:, 0:ns] stop_logits = out[:, ns : ns + 1] add_node_logits = out[:, ns + 1 :] + stop_mask = torch.ones_like(stop_logits) + if self.min_len > 0: + # The +1 accounts for the BOS token + stop_mask = stop_mask * (xs.lens >= self.min_len + 1).unsqueeze(-1).float() + stop_logits = stop_logits * stop_mask - 1000 * (1 - stop_mask) return ( GraphActionCategorical( @@ -111,6 +126,7 @@ def forward(self, xs: SeqBatch, cond, batched=False): logits=[stop_logits, add_node_logits], keys=[None, None], types=self.ctx.action_type_order, + masks=[stop_mask, torch.ones_like(add_node_logits)], slice_dict={}, ), state_preds, diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 91d65818..db531074 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -192,6 +192,7 @@ def setup_env_context(self): max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim, fragments=bengio2021flow.FRAGMENTS_18 if self.cfg.task.seh.reduced_frag else bengio2021flow.FRAGMENTS, + min_len=self.cfg.algo.min_len, ) def setup(self): diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 7fe0f24b..6d1df1f0 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -63,6 +63,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.method = "TB" cfg.algo.max_nodes = 10 + cfg.algo.min_len = 10 cfg.algo.max_len = 10 cfg.algo.sampling_tau = 0.9 cfg.algo.illegal_action_logreward = -75 @@ -79,6 +80,7 @@ def setup_model(self): self.model = SeqTransformerGFN( self.ctx, self.cfg, + min_len=self.cfg.algo.min_len, ) def setup_task(self): @@ -93,6 +95,7 @@ def setup_env_context(self): self.ctx = AutoregressiveSeqBuildingContext( "abc", self.task.num_cond_dim, + self.cfg.algo.min_len, ) def setup_algo(self): diff --git a/tests/test_envs.py b/tests/test_envs.py index 204a17cb..b9785b3b 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -49,7 +49,7 @@ def g2h(g): def expand(s, idx): # Recursively expand all children of s - gd = ctx.graph_to_Data(s) + gd = ctx.graph_to_Data(s, t=0) masks = [getattr(gd, gat.mask_name) for gat in ctx.action_type_order] for at, mask in enumerate(masks): if at == 0: # Ignore Stop action @@ -116,7 +116,7 @@ def _test_backwards_mask_equivalence(two_node_states, ctx): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=False) nm = 0 - gd = ctx.graph_to_Data(g) + gd = ctx.graph_to_Data(g, t=0) for u, k in enumerate(ctx.bck_action_type_order): m = getattr(gd, k.mask_name) nm += m.sum() @@ -138,7 +138,7 @@ def _test_backwards_mask_equivalence_ipa(two_node_states, ctx): for i in range(1, len(two_node_states)): g = two_node_states[i] n = env.count_backward_transitions(g, check_idempotent=True) - gd = ctx.graph_to_Data(g) + gd = ctx.graph_to_Data(g, t=0) # To check that we're computing masks correctly, we need to check that there is the same # number of idempotent action classes, i.e. groups of actions that lead to the same parent. equivalence_classes = []