diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1cf64a4a2ebf..4a39530add05 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -4,6 +4,7 @@
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
+    is_torch_geometric_available,
     is_transformers_available,
     is_unidecode_available,
 )
@@ -83,3 +84,8 @@
     from .pipelines import FlaxStableDiffusionPipeline
 else:
     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
+
+if is_torch_geometric_available():
+    from .models import MoleculeGNN
+else:
+    from .utils.dummy_torch_geometric_objects import *  # noqa F403
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 1242ad6fca7f..78eb92dba12d 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_flax_available, is_torch_available
+from ..utils import is_flax_available, is_torch_available, is_torch_geometric_available
 
 
 if is_torch_available():
@@ -23,3 +23,6 @@
 if is_flax_available():
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .vae_flax import FlaxAutoencoderKL
+
+if is_torch_geometric_available():
+    from .molecule_gnn import MoleculeGNN
diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py
new file mode 100644
index 000000000000..be2f035ace09
--- /dev/null
+++ b/src/diffusers/models/molecule_gnn.py
@@ -0,0 +1,679 @@
+# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
+# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models
+from dataclasses import dataclass
+from typing import Callable, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn import Embedding, Linear, Module, ModuleList, Sequential
+
+from torch_geometric.nn import MessagePassing, radius, radius_graph
+from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import dense_to_sparse, to_dense_adj
+from torch_scatter import scatter_add
+from torch_sparse import SparseTensor, coalesce
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+
+
+@dataclass
+class MoleculeGNNOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(num_nodes, 3)`):
+            Hidden states output: the predicted coordinate update for each atom. Output of the last layer of the
+            model.
+    """
+
+    sample: torch.FloatTensor
+
+
+class MultiLayerPerceptron(nn.Module):
+    """
+    Multi-layer Perceptron. Note there is no activation or dropout in the last layer.
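+
+    For example, `MultiLayerPerceptron(128, [128, 64, 1], activation="relu")` builds
+    `Linear(128, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, 1)`.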
+
+    Args:
+        input_dim (int): input dimension
+        hidden_dims (list of int): hidden dimensions
+        activation (str or function, optional): activation function
+        dropout (float, optional): dropout rate
+    """
+
+    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
+        super(MultiLayerPerceptron, self).__init__()
+
+        self.dims = [input_dim] + hidden_dims
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            print(f"Warning: activation {activation} is not a string and will be ignored")
+            self.activation = None
+        if dropout > 0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        self.layers = nn.ModuleList()
+        for i in range(len(self.dims) - 1):
+            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
+
+    def forward(self, x):
+        """"""
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i < len(self.layers) - 1:
+                if self.activation:
+                    x = self.activation(x)
+                if self.dropout:
+                    x = self.dropout(x)
+        return x
+
+
+class ShiftedSoftplus(torch.nn.Module):
+    def __init__(self):
+        super(ShiftedSoftplus, self).__init__()
+        self.shift = torch.log(torch.tensor(2.0)).item()
+
+    def forward(self, x):
+        return F.softplus(x) - self.shift
+
+
+class CFConv(MessagePassing):
+    def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):
+        super(CFConv, self).__init__(aggr="add")
+        self.lin1 = Linear(in_channels, num_filters, bias=False)
+        self.lin2 = Linear(num_filters, out_channels)
+        self.nn = mlp
+        self.cutoff = cutoff
+        self.smooth = smooth
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.lin1.weight)
+        torch.nn.init.xavier_uniform_(self.lin2.weight)
+        self.lin2.bias.data.fill_(0)
+
+    def forward(self, x, edge_index, edge_length, edge_attr):
+        if self.smooth:
+            C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)
+            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff
+        else:
+            C = (edge_length <= self.cutoff).float()
+        W = self.nn(edge_attr) * C.view(-1, 1)
+
+        x = self.lin1(x)
+        x = self.propagate(edge_index, x=x, W=W)
+        x = self.lin2(x)
+        return x
+
+    def message(self, x_j: torch.Tensor, W) -> torch.Tensor:
+        return x_j * W
+
+
+class InteractionBlock(torch.nn.Module):
+    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):
+        super(InteractionBlock, self).__init__()
+        mlp = Sequential(
+            Linear(num_gaussians, num_filters),
+            ShiftedSoftplus(),
+            Linear(num_filters, num_filters),
+        )
+        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)
+        self.act = ShiftedSoftplus()
+        self.lin = Linear(hidden_channels, hidden_channels)
+
+    def forward(self, x, edge_index, edge_length, edge_attr):
+        x = self.conv(x, edge_index, edge_length, edge_attr)
+        x = self.act(x)
+        x = self.lin(x)
+        return x
+
+
+class SchNetEncoder(Module):
+    def __init__(
+        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False
+    ):
+        super().__init__()
+
+        self.hidden_channels = hidden_channels
+        self.num_filters = num_filters
+        self.num_interactions = num_interactions
+        self.cutoff = cutoff
+
+        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)
+
+        self.interactions = ModuleList()
+        for _ in range(num_interactions):
+            block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)
+            self.interactions.append(block)
+
+    def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):
+        if embed_node:
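+            # `z` holds integer atom types, which are mapped to learned node embeddings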
+            assert z.dim() == 1 and z.dtype == torch.long
+            h = self.embedding(z)
+        else:
+            h = z
+        for interaction in self.interactions:
+            h = h + interaction(h, edge_index, edge_length, edge_attr)
+
+        return h
+
+
+class GINEConv(MessagePassing):
+    """
+    Custom class of the graph isomorphism operator from the "How Powerful are Graph Neural Networks?" paper
+    (https://arxiv.org/abs/1810.00826). Note that this implementation has the added option of a custom activation.
+    """
+
+    def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs):
+        super(GINEConv, self).__init__(aggr="add", **kwargs)
+        self.nn = mlp
+        self.initial_eps = eps
+
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = None
+
+        if train_eps:
+            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
+        else:
+            self.register_buffer("eps", torch.Tensor([eps]))
+
+    def forward(
+        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None
+    ) -> torch.Tensor:
+        """"""
+        if isinstance(x, torch.Tensor):
+            x: OptPairTensor = (x, x)
+
+        # Node and edge feature dimensionalities need to match.
+        if isinstance(edge_index, torch.Tensor):
+            assert edge_attr is not None
+            assert x[0].size(-1) == edge_attr.size(-1)
+        elif isinstance(edge_index, SparseTensor):
+            assert x[0].size(-1) == edge_index.size(-1)
+
+        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
+        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
+
+        x_r = x[1]
+        if x_r is not None:
+            out += (1 + self.eps) * x_r
+
+        return self.nn(out)
+
+    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
+        if self.activation:
+            return self.activation(x_j + edge_attr)
+        else:
+            return x_j + edge_attr
+
+    def __repr__(self):
+        return "{}(nn={})".format(self.__class__.__name__, self.nn)
+
+
+class GINEncoder(torch.nn.Module):
+    def __init__(self, hidden_dim, num_convs=3, activation="relu", short_cut=True, concat_hidden=False):
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.num_convs = num_convs
+        self.short_cut = short_cut
+        self.concat_hidden = concat_hidden
+        self.node_emb = nn.Embedding(100, hidden_dim)
+
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = None
+
+        self.convs = nn.ModuleList()
+        for i in range(self.num_convs):
+            self.convs.append(
+                GINEConv(
+                    MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),
+                    activation=activation,
+                )
+            )
+
+    def forward(self, z, edge_index, edge_attr):
+        """
+        Input:
+            z: node (atom) types, shape (num_node,)
+            edge_index: bond indices of the original graph, shape (2, num_edge)
+            edge_attr: edge feature tensor, shape (num_edge, hidden)
+        Output:
+            node_feature: node-wise features
+        """
+
+        node_attr = self.node_emb(z)  # (num_node, hidden)
+
+        hiddens = []
+        conv_input = node_attr  # (num_node, hidden)
+
+        for conv_idx, conv in enumerate(self.convs):
+            hidden = conv(conv_input, edge_index, edge_attr)
+            if conv_idx < len(self.convs) - 1 and self.activation is not None:
+                hidden = self.activation(hidden)
+            assert hidden.shape == conv_input.shape
+            if self.short_cut and hidden.shape == conv_input.shape:
+                hidden += conv_input
+
+            hiddens.append(hidden)
+            conv_input = hidden
+
+        if self.concat_hidden:
+            node_feature = torch.cat(hiddens, dim=-1)
+        else:
+            node_feature = hiddens[-1]
+
+        return node_feature
+
+
+class MLPEdgeEncoder(Module):
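+    # Encodes each edge by multiplying an MLP embedding of its length with a learned embedding
+    # of its bond type, so both geometry and chemistry shape the edge feature.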
activation="relu"): + super().__init__() + self.hidden_dim = hidden_dim + self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim) + self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation) + + @property + def out_channels(self): + return self.hidden_dim + + def forward(self, edge_length, edge_type): + """ + Input: + edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,) + Returns: + edge_attr: The representation of edges. (E, 2 * num_gaussians) + """ + d_emb = self.mlp(edge_length) # (num_edge, hidden_dim) + edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim) + return d_emb * edge_attr # (num_edge, hidden) + + +def assemble_atom_pair_feature(node_attr, edge_index, edge_attr): + h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]] + h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H) + return h_pair + + +def _extend_graph_order(num_nodes, edge_index, edge_type, order=3): + """ + Args: + num_nodes: Number of atoms. + edge_index: Bond indices of the original graph. + edge_type: Bond types of the original graph. + order: Extension order. + Returns: + new_edge_index: Extended edge indices. new_edge_type: Extended edge types. + """ + + def binarize(x): + return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) + + def get_higher_order_adj_matrix(adj, order): + """ + Args: + adj: (N, N) + type_mat: (N, N) + Returns: + Following attributes will be updated: + - edge_index + - edge_type + Following attributes will be added to the data object: + - bond_edge_index: Original edge_index. + """ + adj_mats = [ + torch.eye(adj.size(0), dtype=torch.long, device=adj.device), + binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)), + ] + + for i in range(2, order + 1): + adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1])) + order_mat = torch.zeros_like(adj) + + for i in range(1, order + 1): + order_mat += (adj_mats[i] - adj_mats[i - 1]) * i + + return order_mat + + num_types = 22 + # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} + # from rdkit.Chem.rdchem import BondType as BT + N = num_nodes + adj = to_dense_adj(edge_index).squeeze(0) + adj_order = get_higher_order_adj_matrix(adj, order) # (N, N) + + type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N) + type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order)) + assert (type_mat * type_highorder == 0).all() + type_new = type_mat + type_highorder + + new_edge_index, new_edge_type = dense_to_sparse(type_new) + _, edge_order = dense_to_sparse(adj_order) + + # data.bond_edge_index = data.edge_index # Save original edges + new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data + + return new_edge_index, new_edge_type + + +def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None): + assert edge_type.dim() == 1 + N = pos.size(0) + + bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N])) + + if is_sidechain is None: + rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r) + else: + # fetch sidechain and its batch index + is_sidechain = is_sidechain.bool() + dummy_index = torch.arange(pos.size(0), device=pos.device) + sidechain_pos = pos[is_sidechain] + sidechain_index = dummy_index[is_sidechain] + sidechain_batch = batch[is_sidechain] + + assign_index = 
+        assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)
+        r_edge_index_x = assign_index[1]
+        r_edge_index_y = assign_index[0]
+        r_edge_index_y = sidechain_index[r_edge_index_y]
+
+        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)
+        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # (2, E)
+        rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1)  # (2, 2E)
+        # delete self loops
+        rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]
+
+    rgraph_adj = torch.sparse.LongTensor(
+        rgraph_edge_index,
+        torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,
+        torch.Size([N, N]),
+    )
+
+    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N)
+
+    new_edge_index = composed_adj.indices()
+    new_edge_type = composed_adj.values().long()
+
+    return new_edge_index, new_edge_type
+
+
+def extend_graph_order_radius(
+    num_nodes,
+    pos,
+    edge_index,
+    edge_type,
+    batch,
+    order=3,
+    cutoff=10.0,
+    extend_order=True,
+    extend_radius=True,
+    is_sidechain=None,
+):
+    if extend_order:
+        edge_index, edge_type = _extend_graph_order(
+            num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order
+        )
+
+    if extend_radius:
+        edge_index, edge_type = _extend_to_radius_graph(
+            pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain
+        )
+
+    return edge_index, edge_type
+
+
+def get_distance(pos, edge_index):
+    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)
+
+
+def graph_field_network(score_d, pos, edge_index, edge_length):
+    """
+    Transformation to make the epsilon predicted from the diffusion model roto-translationally equivariant. See
+    equations 5-7 of the GeoDiff paper, https://arxiv.org/pdf/2203.02923.pdf
+    """
+    N = pos.size(0)
+    dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]])  # (E, 3)
+    score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(
+        -dd_dr * score_d, edge_index[1], dim=0, dim_size=N
+    )  # (N, 3)
+    return score_pos
+
+
+class MoleculeGNN(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        hidden_dim=128,
+        num_convs=6,
+        num_convs_local=4,
+        cutoff=10.0,
+        mlp_act="relu",
+        edge_order=3,
+        edge_encoder="mlp",
+        smooth_conv=True,
+    ):
+        super().__init__()
+        self.cutoff = cutoff
+        self.edge_encoder = edge_encoder
+        self.edge_order = edge_order
+
+        """
+        edge_encoder: Takes both edge type and edge length as input and outputs a vector. [Note]: node embedding is
+        done in SchNetEncoder.
+        """
+        self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)
+        self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)
+
+        """
+        The graph neural network that extracts node-wise features.
+        """
+        self.encoder_global = SchNetEncoder(
+            hidden_channels=hidden_dim,
+            num_filters=hidden_dim,
+            num_interactions=num_convs,
+            edge_channels=self.edge_encoder_global.out_channels,
+            cutoff=cutoff,
+            smooth=smooth_conv,
+        )
+        self.encoder_local = GINEncoder(
+            hidden_dim=hidden_dim,
+            num_convs=num_convs_local,
+        )
+
+        """
+        `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs gradients w.r.t.
+        edge_length (out_dim = 1).
+        """
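+        # Both heads map concatenated pair features (2 * hidden_dim) down to a scalar score per edge.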
+ """ + self.grad_global_dist_mlp = MultiLayerPerceptron( + 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act + ) + + self.grad_local_dist_mlp = MultiLayerPerceptron( + 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act + ) + + """ + Incorporate parameters together + """ + self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) + self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) + + def _forward( + self, + atom_type, + pos, + bond_index, + bond_type, + batch, + time_step, # NOTE, model trained without timestep performed best + edge_index=None, + edge_type=None, + edge_length=None, + return_edges=False, + extend_order=True, + extend_radius=True, + is_sidechain=None, + ): + """ + Args: + atom_type: Types of atoms, (N, ). + bond_index: Indices of bonds (not extended, not radius-graph), (2, E). + bond_type: Bond types, (E, ). + batch: Node index to graph index, (N, ). + """ + N = atom_type.size(0) + if edge_index is None or edge_type is None or edge_length is None: + edge_index, edge_type = extend_graph_order_radius( + num_nodes=N, + pos=pos, + edge_index=bond_index, + edge_type=bond_type, + batch=batch, + order=self.edge_order, + cutoff=self.cutoff, + extend_order=extend_order, + extend_radius=extend_radius, + is_sidechain=is_sidechain, + ) + edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1) + local_edge_mask = is_local_edge(edge_type) # (E, ) + + # with the parameterization of NCSNv2 + # DDPM loss implicit handle the noise variance scale conditioning + sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1) + + # Encoding global + edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges + + # Global + node_attr_global = self.encoder_global( + z=atom_type, + edge_index=edge_index, + edge_length=edge_length, + edge_attr=edge_attr_global, + ) + # Assemble pairwise features + h_pair_global = assemble_atom_pair_feature( + node_attr=node_attr_global, + edge_index=edge_index, + edge_attr=edge_attr_global, + ) # (E_global, 2H) + # Invariant features of edges (radius graph, global) + edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1) + + # Encoding local + edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges + # edge_attr += temb_edge + + # Local + node_attr_local = self.encoder_local( + z=atom_type, + edge_index=edge_index[:, local_edge_mask], + edge_attr=edge_attr_local[local_edge_mask], + ) + # Assemble pairwise features + h_pair_local = assemble_atom_pair_feature( + node_attr=node_attr_local, + edge_index=edge_index[:, local_edge_mask], + edge_attr=edge_attr_local[local_edge_mask], + ) # (E_local, 2H) + + # Invariant features of edges (bond graph, local) + if isinstance(sigma_edge, torch.Tensor): + edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * ( + 1.0 / sigma_edge[local_edge_mask] + ) # (E_local, 1) + else: + edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1) + + if return_edges: + return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask + else: + return edge_inv_global, edge_inv_local + + def forward( + self, + sample, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + sigma=1.0, + global_start_sigma=0.5, + w_global=1.0, + extend_order=False, + 
+    def forward(
+        self,
+        sample,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+        sigma=1.0,
+        global_start_sigma=0.5,
+        w_global=1.0,
+        extend_order=False,
+        extend_radius=True,
+        clip_local=None,
+        clip_global=1000.0,
+    ) -> Union[MoleculeGNNOutput, Tuple]:
+        r"""
+        Args:
+            sample: packed torch geometric object
+            timestep (`torch.FloatTensor` or `float` or `int`): TODO verify type and shape (batch) timesteps
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+            tensor.
+        """
+
+        # unpack sample
+        atom_type = sample.atom_type
+        bond_index = sample.edge_index
+        bond_type = sample.edge_type
+        num_graphs = sample.num_graphs
+        pos = sample.pos
+
+        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)
+
+        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(
+            atom_type=atom_type,
+            pos=sample.pos,
+            bond_index=bond_index,
+            bond_type=bond_type,
+            batch=sample.batch,
+            time_step=timesteps,
+            return_edges=True,
+            extend_order=extend_order,
+            extend_radius=extend_radius,
+        )  # (E_global, 1), (E_local, 1)
+
+        # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff
+        node_eq_local = graph_field_network(
+            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]
+        )
+        if clip_local is not None:
+            node_eq_local = clip_norm(node_eq_local, limit=clip_local)
+
+        # Global
+        if sigma < global_start_sigma:
+            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())
+            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)
+            node_eq_global = clip_norm(node_eq_global, limit=clip_global)
+        else:
+            node_eq_global = 0
+
+        # Sum
+        eps_pos = node_eq_local + node_eq_global * w_global
+
+        if not return_dict:
+            return (-eps_pos,)
+
+        return MoleculeGNNOutput(sample=-eps_pos)
+
+
+def clip_norm(vec, limit, p=2):
+    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
+    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
+    return vec * denom
+
+
+def is_local_edge(edge_type):
+    return edge_type > 0
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 4d4e986a76ea..0a4666c2a795 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -133,6 +133,13 @@ def __init__(
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
+        elif beta_schedule == "sigmoid":
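+            # betas follow a logistic (S-shaped) curve from roughly beta_start to roughly beta_end,
+            # since sigmoid(-6) is about 0.0025 and sigmoid(6) is about 0.9975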
+
+            def sigmoid(x):
+                return 1 / (np.exp(-x) + 1)
+
+            betas = np.linspace(-6, 6, num_train_timesteps)
+            self.betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c1285bb8c23d..9f38bc215ff6 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -30,6 +30,7 @@
     is_scipy_available,
     is_tf_available,
     is_torch_available,
+    is_torch_geometric_available,
     is_transformers_available,
     is_unidecode_available,
     requires_backends,
diff --git a/src/diffusers/utils/dummy_torch_geometric_objects.py b/src/diffusers/utils/dummy_torch_geometric_objects.py
new file mode 100644
index 000000000000..adb24fe6730c
--- /dev/null
+++ b/src/diffusers/utils/dummy_torch_geometric_objects.py
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class MoleculeGNN(metaclass=DummyObject):
+    _backends = ["torch_geometric"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch_geometric"])
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index de344d074da0..7c651ec8b39e 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -159,6 +159,20 @@
 except importlib_metadata.PackageNotFoundError:
     _scipy_available = False
 
+_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None
+try:
+    _torch_scatter_version = importlib_metadata.version("torch_scatter")
+    logger.debug(f"Successfully imported torch_scatter version {_torch_scatter_version}")
+except importlib_metadata.PackageNotFoundError:
+    _torch_scatter_available = False
+
+_torch_geometric_available = importlib.util.find_spec("torch_geometric") is not None
+try:
+    _torch_geometric_version = importlib_metadata.version("torch_geometric")
+    logger.debug(f"Successfully imported torch_geometric version {_torch_geometric_version}")
+except importlib_metadata.PackageNotFoundError:
+    _torch_geometric_available = False
+
 
 def is_torch_available():
     return _torch_available
@@ -196,6 +210,18 @@ def is_scipy_available():
     return _scipy_available
 
 
+def is_torch_scatter_available():
+    return _torch_scatter_available
+
+
+def is_torch_geometric_available():
+    # the model source of the Molecule Generation GNN requires a specific torch_geometric version;
+    # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in the README
+    if not _torch_geometric_available:
+        return False
+    return _torch_geometric_version == "1.7.2"
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -244,6 +270,12 @@ def is_scipy_available():
 Unidecode`
 """
+
+# docstyle-ignore
+TORCH_GEOMETRIC_IMPORT_ERROR = """
+{0} requires version 1.7.2 of torch_geometric but it was not found in your environment. You can install it with conda:
+`conda install -c rusty1s pytorch-geometric=1.7.2`, given PyTorch 1.8
+"""
 
 
 BACKENDS_MAPPING = OrderedDict(
     [
diff --git a/tests/test_models_gnn.py b/tests/test_models_gnn.py
new file mode 100644
index 000000000000..a0d87b835b44
--- /dev/null
+++ b/tests/test_models_gnn.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+import torch
+
+from diffusers.utils import is_torch_geometric_available
+from diffusers.utils.testing_utils import torch_device
+
+from .test_modeling_common import ModelTesterMixin
+
+
+if is_torch_geometric_available():
+    from diffusers import MoleculeGNN
+else:
+    from diffusers.utils.dummy_torch_geometric_objects import *  # noqa F403
+
+
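+# TF32 matmuls on Ampere GPUs trade precision for speed, which can make the exact-value
+# comparisons in these tests flaky, so TF32 is disabled here.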
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class MoleculeGNNTests(ModelTesterMixin, unittest.TestCase):
+    model_class = MoleculeGNN
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        time_step = 10
+
+        class GeoDiffData:
+            # constants corresponding to a molecule
+            num_nodes = 6
+            num_edges = 10
+            num_graphs = 1
+
+            # sampling
+            torch.Generator(device=torch_device)
+            torch.manual_seed(3)
+
+            # molecule components / properties
+            atom_type = torch.randint(0, 6, (num_nodes * batch_size,)).to(torch_device)
+            edge_index = torch.randint(
+                0,
+                num_edges,
+                (
+                    2,
+                    num_edges * batch_size,
+                ),
+            ).to(torch_device)
+            edge_type = torch.randint(0, 5, (num_edges * batch_size,)).to(torch_device)
+            pos = 0.001 * torch.randn(num_nodes * batch_size, 3).to(torch_device)
+            batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes)
+            nx = batch_size
+
+        torch.manual_seed(0)
+        noise = GeoDiffData()
+
+        return {"sample": noise, "timestep": time_step, "sigma": 1.0}
+
+    @property
+    def output_shape(self):
+        # subset of shapes for dummy input
+        class GeoDiffShapes:
+            shape_0 = torch.Size([1305, 1])
+            shape_1 = torch.Size([92, 1])
+
+        return GeoDiffShapes()
+
+    # training not implemented for this model yet
+    def test_training(self):
+        pass
+
+    def test_ema_training(self):
+        pass
+
+    def test_determinism(self):
+        # TODO
+        pass
+
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output["sample"]
+
+        self.assertIsNotNone(output)
+        shapes = self.output_shape
+        self.assertEqual(output[0].shape, shapes.shape_0, "Input and output shapes do not match")
+        self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match")
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "hidden_dim": 128,
+            "num_convs": 6,
+            "num_convs_local": 4,
+            "cutoff": 10.0,
+            "mlp_act": "relu",
+            "edge_order": 3,
+            "edge_encoder": "mlp",
+            "smooth_conv": True,
+        }
+
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        input = self.dummy_input
+        sample, time_step, sigma = input["sample"], input["timestep"], input["sigma"]
+        with torch.no_grad():
+            output = model(sample, time_step, sigma=sigma)["sample"]
+
+        output_slice = output[:3, :].flatten()
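+        # these reference values depend on the fixed seeds in `dummy_input`; changing the
+        # seeding or sampling order above will likely invalidate this slice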
+        # fmt: off
+        expected_output_slice = torch.tensor([-3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, 4.2985, 8.8828])
+        # fmt: on
+
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
+
+    def test_model_from_config(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        # test if the model can be loaded from the config
+        # and has all the expected shapes
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_config(tmpdirname)
+            new_model = self.model_class.from_config(tmpdirname)
+            new_model.to(torch_device)
+            new_model.eval()
+
+        # check if all parameter shapes are the same
+        for param_name in model.state_dict().keys():
+            param_1 = model.state_dict()[param_name]
+            param_2 = new_model.state_dict()[param_name]
+            self.assertEqual(param_1.shape, param_2.shape)
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict)
+
+            if isinstance(output_1, dict):
+                output_1 = output_1["sample"]
+
+            output_2 = new_model(**inputs_dict)
+
+            if isinstance(output_2, dict):
+                output_2 = output_2["sample"]
+
+        self.assertEqual(output_1[0].shape, output_2[0].shape)
+        self.assertEqual(output_1[1].shape, output_2[1].shape)
+
+    def test_forward_with_norm_groups(self):
+        # not implemented for this model
+        pass