diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py index 722aa0f7d..ff8a93c76 100644 --- a/mindnlp/transformers/models/__init__.py +++ b/mindnlp/transformers/models/__init__.py @@ -135,6 +135,7 @@ luke, lxmert, mamba, + mamba2, marian, markuplm, m2m_100, @@ -381,6 +382,7 @@ from .lxmert import * from .m2m_100 import * from .mamba import * +from .mamba2 import * from .marian import * from .markuplm import * from .maskformer import * @@ -626,6 +628,7 @@ __all__.extend(lxmert.__all__) __all__.extend(m2m_100.__all__) __all__.extend(mamba.__all__) +__all__.extend(mamba2.__all__) __all__.extend(marian.__all__) __all__.extend(markuplm.__all__) __all__.extend(maskformer.__all__) diff --git a/mindnlp/transformers/models/auto/configuration_auto.py b/mindnlp/transformers/models/auto/configuration_auto.py index 73d5851f2..96ae6008e 100644 --- a/mindnlp/transformers/models/auto/configuration_auto.py +++ b/mindnlp/transformers/models/auto/configuration_auto.py @@ -135,6 +135,7 @@ ("lxmert", "LxmertConfig"), ("m2m_100", "M2M100Config"), ("mamba", "MambaConfig"), + ("mamba2", "Mamba2Config"), ("marian", "MarianConfig"), ('markuplm', "MarkupLMConfig"), ("mask2former", "Mask2FormerConfig"), @@ -353,6 +354,7 @@ ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mamba2", "MAMBA2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("marian", "MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -608,6 +610,7 @@ ("lxmert", "LXMERT"), ("m2m_100", "M2M100"), ("mamba", "Mamba"), + ("mamba2", "Mamba2"), ("marian", "Marian"), ("markuplm", "MarkupLM"), ("mask2former", "Mask2Former"), diff --git a/mindnlp/transformers/models/auto/modeling_auto.py b/mindnlp/transformers/models/auto/modeling_auto.py index 026ea2a43..3a8d5cb33 100644 --- a/mindnlp/transformers/models/auto/modeling_auto.py +++ b/mindnlp/transformers/models/auto/modeling_auto.py @@ -151,6 +151,7 @@ ("lxmert", "LxmertModel"), ("m2m_100", "M2M100Model"), ("mamba", "MambaModel"), + ("mamba2", "Mamba2Model"), ("marian", "MarianModel"), ("markuplm", "MarkupLMModel"), ("mask2former", "Mask2FormerModel"), @@ -318,6 +319,7 @@ ("luke", "LukeForMaskedLM"), ("lxmert", "LxmertForPreTraining"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForPreTraining"), ('minicpm', 'MiniCPMForCausalLM'), @@ -405,6 +407,7 @@ ("luke", "LukeForMaskedLM"), ("m2m_100", "M2M100ForConditionalGeneration"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianMTModel"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForCausalLM"), @@ -491,6 +494,7 @@ ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), ("mega", "MegaForCausalLM"), diff --git a/mindnlp/transformers/models/auto/tokenization_auto.py b/mindnlp/transformers/models/auto/tokenization_auto.py index 1ad0adbe4..054141fad 100644 --- a/mindnlp/transformers/models/auto/tokenization_auto.py +++ b/mindnlp/transformers/models/auto/tokenization_auto.py @@ -269,6 +269,7 @@ ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, 
None)), ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), ( "mbart", diff --git a/mindnlp/transformers/models/mamba2/__init__.py b/mindnlp/transformers/models/mamba2/__init__.py new file mode 100644 index 000000000..74e92a14b --- /dev/null +++ b/mindnlp/transformers/models/mamba2/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mamba2 Model. +""" +from . import modeling_mamba2, configuration_mamba2 +from .modeling_mamba2 import * +from .configuration_mamba2 import * + +__all__ = [] +__all__.extend(modeling_mamba2.__all__) +__all__.extend(configuration_mamba2.__all__) diff --git a/mindnlp/transformers/models/mamba2/configuration_mamba2.py b/mindnlp/transformers/models/mamba2/configuration_mamba2.py new file mode 100644 index 000000000..c884be60e --- /dev/null +++ b/mindnlp/transformers/models/mamba2/configuration_mamba2.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MAMBA2 configuration""" + +import math + +from mindnlp.utils import logging +from ...configuration_utils import PretrainedConfig + +logger = logging.get_logger(__name__) + +class Mamba2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA2 + [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_heads (`int`, *optional*, defaults to 128): + Number of heads for the evolution matrices of mamba 2. + head_dim (`int`, *optional*, defaults to 64): + Dimension of each head. + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Mamba2Model`]. 
+ hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 128): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 64): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + n_groups (`int`, *optional*, defaults to 8): + Number of groups for the evolution matrices of mamba 2. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + Accepted range of time step values. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + rms_norm (`bool`, *optional*, defaults to `True`): + Whether to use RMS norm or not. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie word embeddings or not. 
+ + + Example: + + ```python + >>> from transformers import Mamba2Config, Mamba2Model + + >>> # Initializing a Mamba2 configuration + >>> configuration = Mamba2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = Mamba2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba2" + + def __init__( + self, + num_heads=128, + head_dim=64, + vocab_size=32768, + hidden_size=4096, + state_size=128, + num_hidden_layers=64, + layer_norm_epsilon=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + expand=2, + conv_kernel=4, + n_groups=8, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + rescale_prenorm_residual=False, + use_cache=True, + rms_norm=True, + chunk_size=256, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.n_groups = n_groups + self.num_heads = num_heads + self.head_dim = head_dim + self.rms_norm = rms_norm + self.state_size = state_size + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.tie_word_embeddings = tie_word_embeddings + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Mamba2Config"] diff --git a/mindnlp/transformers/models/mamba2/modeling_mamba2.py b/mindnlp/transformers/models/mamba2/modeling_mamba2.py new file mode 100644 index 000000000..d6044b53c --- /dev/null +++ b/mindnlp/transformers/models/mamba2/modeling_mamba2.py @@ -0,0 +1,917 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
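
The modeling code that follows derives several tensor shapes from the configuration fields defined above. As a reader aid (not part of the patch), here is a minimal pure-Python sketch of that shape arithmetic, using the default `Mamba2Config` values from `configuration_mamba2.py`; the variable names are illustrative only and mirror the expressions used by `Mamba2Mixer` and `Mamba2Cache` in this diff.

```python
# Reader-aid sketch (not part of the patch): shape arithmetic that the
# Mamba2Mixer / Mamba2Cache code below derives from the default Mamba2Config.
import math

# Default Mamba2Config values from configuration_mamba2.py above.
hidden_size, expand = 4096, 2
n_groups, state_size = 8, 128
num_heads, head_dim = 128, 64
conv_kernel = 4

intermediate_size = int(expand * hidden_size)                # 8192
conv_dim = intermediate_size + 2 * n_groups * state_size     # 10240, channels fed to the grouped conv1d
projection_size = intermediate_size + conv_dim + num_heads   # 18560, output width of in_proj
time_step_rank = math.ceil(hidden_size / 16)                 # 256 when time_step_rank="auto"

# With the defaults, the intermediate size equals num_heads * head_dim.
assert intermediate_size == num_heads * head_dim

# Per-layer, per-batch cache shapes as allocated in Mamba2Cache:
batch_size = 1
conv_state_shape = (batch_size, conv_dim, conv_kernel)           # rolling convolution buffer
ssm_state_shape = (batch_size, num_heads, head_dim, state_size)  # recurrent SSM state
print(projection_size, conv_state_shape, ssm_state_shape)
```
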
+"""MindSpore MAMBA2 model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import mindspore +from mindnlp.core import nn, ops, no_grad +from mindnlp.core.nn import CrossEntropyLoss + +from ....common.activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_utils import PreTrainedModel + +from ....utils import ( + ModelOutput, + logging, +) + +from .configuration_mamba2 import Mamba2Config + + +logger = logging.get_logger(__name__) + + + +_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" +_CONFIG_FOR_DOC = "Mamba2Config" + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: mindspore.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len axis (axis=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len axis (axis=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.shape[-1] + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor.unsqueeze(-1) + target_shape = tuple(input_tensor.shape[:-1] + (chunk_size,)) + input_tensor = input_tensor.broadcast_to(target_shape) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = ops.tril(ops.ones(chunk_size, chunk_size, dtype=mindspore.bool_), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, mindspore.Tensor(0, dtype=input_tensor.dtype)) + # 3. compute actual cumsum + tensor_segsum = ops.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = ops.tril(ops.ones(chunk_size, chunk_size, dtype=mindspore.bool_), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, mindspore.Tensor(float('-inf'), dtype=tensor_segsum.dtype)) + return tensor_segsum + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + +# Simple roll function for CPU and NPU +if mindspore.context.get_context("device_target") == "GPU": + from mindspore.ops import roll +else: + def roll(x: mindspore.Tensor, shifts, dims=None): + """ + + Args: + x (mindspore.Tensor): Input tensor + shifts (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements in the opposite direction. + dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled. Default: None. If dims is None, the Tensor will be flattened before rolling and then restored to the original shape. + Returns: + Tensor, has the same shape and type as input. + """ + # If dims is None, first flatten the tensor + if dims is None: + x = x.reshape(-1) + dims = 0 + + # Convert shifts and dims to lists if they are not already + if isinstance(shifts, int): + shifts = [shifts] + if isinstance(dims, int): + dims = [dims] + + # Ensure shifts and dims have the same length + if len(shifts) != len(dims): + raise ValueError("shifts and dims must have the same length") + + # Move each dimension + for shift, dim in zip(shifts, dims): + # Handle negative shifts + if shift < 0: + shift = x.shape[dim] + shift + + # Normalize shift, ensuring it is within valid range + shift = shift % x.shape[dim] + + if shift == 0: + continue + + # Split at the specified dimension + indices = list(range(x.ndim)) + indices[0], indices[dim] = indices[dim], indices[0] + x = x.swapaxes(0, dim) # Move the target dimension to the first dimension + + shape = x.shape + x = x.reshape(shape[0], -1) # Flatten the other dimensions + + # Perform roll operation + x = ops.concat([x[shape[0]-shift:], x[:shape[0]-shift]], dim=0) + + # Restore original shape + x = x.reshape(shape) + x = x.swapaxes(0, dim) # Restore dimensions + + return x + +class Mamba2Cache: + """ + Arguments: + config: Mamba2Config + batch_size: int + dtype: mindspore.dtype + + Attributes: + dtype: (`mindspore.dtype`): + The default `dtype` used to initializing the cache. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config. + n_groups: (`int`): + Model's number of groups taken from the config - similar to tensor parallel in Transformer. + state_size: (`int`): + Model's SSM state size taken from config. + num_heads: (`int`): + The number of heads used in the linear attention / SSM. + head_dim: (`int`): + The respective dimension of the heads used in the linear attention / SSM. + intermediate_size: (`int`): + Model's intermediate_size based on (expand * hidden_dim) from config. 
+ conv_states: (`mindspore.Tensor`): + A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. + ssm_states: (`mindspore.Tensor`): + A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. + """ + + def __init__( + self, config: Mamba2Config, batch_size: int, dtype: mindspore.dtype = mindspore.float16): + self.dtype = dtype + self.conv_kernel_size = config.conv_kernel + self.n_groups = config.n_groups + self.state_size = config.state_size + self.num_heads = config.num_heads + self.head_dim = config.head_dim + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = ops.zeros( + (config.num_hidden_layers, + batch_size, + self.intermediate_size + 2 * self.n_groups * self.state_size, + self.conv_kernel_size), + dtype=dtype, + ) + self.ssm_states = ops.zeros( + (config.num_hidden_layers, + batch_size, + self.num_heads, + self.head_dim, + self.state_size), + dtype=dtype, + ) + + def update_conv_state( + self, layer_idx: int, new_conv_state: mindspore.Tensor, cache_init: bool = False + ) -> mindspore.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state + else: + self.conv_states[layer_idx] = roll(self.conv_states[layer_idx], shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :] + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: mindspore.Tensor): + self.ssm_states[layer_idx] = new_ssm_state + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(ops.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype=mindspore.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(dtype=mindspore.float32)) + variance = hidden_states.pow(2).mean(-1, keep_dims=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Mamba2Config, layer_idx: int): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(ops.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = ops.arange(1, self.num_heads + 1).astype(mindspore.float32) + self.A_log = nn.Parameter(ops.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(ops.ones(self.num_heads)) + self.D._no_weight_decay = True + + # use_bias (`bool`, *optional*, defaults to `False`) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + # fmt: off + def mindspore_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[mindspore.Tensor]=None, attention_mask: Optional[mindspore.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = ops.split( + projected_states, [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + conv_states = cache_params.conv_states[self.layer_idx] + + hidden_states_B_C = ops.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.swapaxes(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.swapaxes(1, 2))[..., :seq_len].swapaxes(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = ops.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -ops.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # Delete 'device' in mindspore + cache_device = cache_params.ssm_states + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.swapaxes(1, 2).broadcast_to((batch_size, dt.shape[-1], self.head_dim)) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].broadcast_to((self.dt_bias.shape[0], self.head_dim)) + + dt = nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = ops.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].broadcast_to((self.num_heads, self.head_dim, self.ssm_state_size)).to(dtype=mindspore.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (ops.exp(dt[..., None] * A)) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1])).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1])) + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(dtype=C.dtype) # Shape: [b, h, d, n] + + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = 
ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = ops.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].broadcast_to((self.D.shape[0], self.head_dim)) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = ops.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.tile((1, 1, self.num_heads // self.n_groups, 1)) + C = C.tile((1, 1, self.num_heads // self.n_groups, 1)) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = ops.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = ops.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(axis=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(axis=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(axis=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = ops.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(axis=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = ops.zeros_like(states[:, :1]) + states = ops.cat([previous_states, states], dim=1) + decay_chunk = ops.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.swapaxes(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(axis=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = ops.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(axis=-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + + # fmt: on + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + ): + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.mindspore_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class Mamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(ops.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mindspore.float32) + variance = ops.mean(hidden_states.pow(2), -1, keepdim=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(dtype=mindspore.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + return hidden_states + + +class Mamba2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = ops.exp( + ops.rand(self.config.num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + ops.log(-ops.expm1(-dt)) + with no_grad(): + module.dt_bias.assign_value(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. + + Args: + last_hidden_state (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[mindspore.Tensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 +class Mamba2CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`mindspore.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`mindspore.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: Optional[mindspore.Tensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + +class Mamba2Model(Mamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.shape[0], dtype=inputs_embeds.dtype + ) + cache_position = ops.arange(0, self.config.conv_kernel, dtype=mindspore.int64) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [] + + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def 
set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, + ): + # Overwitten -- uses `cache_params` as opposed to `past_key_values` + + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + + if attention_mask is not None: + attention_mask = None + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = ops.arange(0, self.config.conv_kernel, dtype=mindspore.int64) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + cache_params: Optional[Mamba2Cache] = None, + labels: Optional[mindspore.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba2_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba2_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba2_outputs.cache_params, + hidden_states=mamba2_outputs.hidden_states, + ) + + +__all__ = ["Mamba2ForCausalLM", "Mamba2Model", "Mamba2PreTrainedModel"] diff --git a/mindnlp/utils/import_utils.py b/mindnlp/utils/import_utils.py index 547f2efd9..8a902abba 100644 --- a/mindnlp/utils/import_utils.py +++ b/mindnlp/utils/import_utils.py @@ -382,6 +382,11 @@ def is_essentia_available(): """ return _essentia_available +def is_mamba_2_ssm_available(): + return _is_package_available("mamba_ssm") + +def is_causal_conv1d_available(): + return _is_package_available("causal_conv1d") def is_pyctcdecode_available(): """ diff --git a/mindnlp/utils/testing_utils.py b/mindnlp/utils/testing_utils.py index 13d4b4f1d..466c4e4c3 100644 --- a/mindnlp/utils/testing_utils.py +++ b/mindnlp/utils/testing_utils.py @@ -262,6 +262,78 @@ def require_librosa(test_case): """ return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) +################################################################################ +### update_wrapper() and wraps() decorator +################################################################################ + +# update_wrapper() and wraps() are tools to help write +# wrapper functions that can handle naive introspection +# Note from mamba2 model porting: Original mamba2 code require python 3.13+ +# so we copy the codes from python 3.13+ +WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', + '__annotations__', '__type_params__') +WRAPPER_UPDATES = ('__dict__',) +def update_wrapper(wrapper, + wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Update a wrapper function to look like the wrapped function + + wrapper is the function to be updated + wrapped is the original function + assigned is a tuple naming the attributes assigned directly + from the wrapped function to the wrapper function (defaults to + functools.WRAPPER_ASSIGNMENTS) + updated is a tuple naming the attributes of the wrapper that + are updated with the corresponding attribute from the wrapped + function (defaults to functools.WRAPPER_UPDATES) + """ + for attr in assigned: + try: + value = getattr(wrapped, attr) + except AttributeError: + pass + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + # Issue 
#17482: set __wrapped__ last so we don't inadvertently copy it + # from the wrapped function when updating __dict__ + wrapper.__wrapped__ = wrapped + # Return the wrapper so this can be used as a decorator via partial() + return wrapper + +def wraps(wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Decorator factory to apply update_wrapper() to a wrapper function + + Returns a decorator that invokes update_wrapper() with the decorated + function as the wrapper argument and the arguments to wraps() as the + remaining arguments. Default arguments are as for update_wrapper(). + This is a convenience function to simplify applying partial() to + update_wrapper(). + """ + return functools.partial(update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) + +def require_read_token(fn): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN") + + @wraps(fn) + def _inner(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return fn(*args, **kwargs) + else: # Allow running locally with the default token env variable + return fn(*args, **kwargs) + + return _inner + + def require_essentia(test_case): """ Decorator marking a test that requires essentia diff --git a/tests/transformers/generation/test_utils.py b/tests/transformers/generation/test_utils.py index 35b486d6a..74ed50116 100644 --- a/tests/transformers/generation/test_utils.py +++ b/tests/transformers/generation/test_utils.py @@ -1630,16 +1630,16 @@ def test_generate_from_inputs_embeds_decoder_only(self): # Traditional way of generating text outputs_from_ids = model.generate( - input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + input_ids, max_new_tokens=1, return_dict_in_generate=True, output_scores=True ) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) + self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 1)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) inputs_embeds = model.get_input_embeddings()(input_ids) outputs_from_embeds = model.generate( input_ids, inputs_embeds=inputs_embeds, - max_new_tokens=5, + max_new_tokens=1, return_dict_in_generate=True, output_scores=True, ) @@ -1651,7 +1651,7 @@ def test_generate_from_inputs_embeds_decoder_only(self): outputs_from_rand_embeds = model.generate( input_ids, inputs_embeds=random_embeds, - max_new_tokens=5, + max_new_tokens=1, return_dict_in_generate=True, output_scores=True, ) @@ -1660,7 +1660,7 @@ def test_generate_from_inputs_embeds_decoder_only(self): # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + inputs_embeds=inputs_embeds, max_new_tokens=1, return_dict_in_generate=True, output_scores=True ) self.assertListEqual( outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), diff --git a/tests/transformers/models/mamba2/__init__.py b/tests/transformers/models/mamba2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/transformers/models/mamba2/test_modeling_mamba2.py b/tests/transformers/models/mamba2/test_modeling_mamba2.py new file mode 100644 index 000000000..1a684efa4 --- /dev/null +++ 
b/tests/transformers/models/mamba2/test_modeling_mamba2.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from typing import Dict, List, Tuple + + +from mindnlp.transformers import AutoTokenizer, Mamba2Config, is_mindspore_available +from mindnlp.utils.testing_utils import require_read_token, slow, require_mindspore + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_mindspore_available(): + import mindspore + from mindnlp.core import ops, nn, no_grad + + from mindnlp.transformers import ( + Mamba2ForCausalLM, + Mamba2Model, + ) + from mindnlp.transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer + + +class Mamba2ModelTester: + def __init__( + self, + parent, + batch_size=14, + num_heads=8, + n_groups=8, + state_size=2, + head_dim=8, + conv_kernel=4, + chunk_size=8, + seq_length=7, + is_training=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + hidden_act="silu", + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + scope=None, + tie_word_embeddings=False, + ): + self.parent = parent + self.num_heads = num_heads + self.n_groups = n_groups + self.head_dim = head_dim + self.state_size = state_size + self.conv_kernel = conv_kernel + self.chunk_size = chunk_size + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + self.tie_word_embeddings = tie_word_embeddings + + def get_large_model_config(self): + return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", from_pt=True) + + def prepare_config_and_inputs( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + # Only left padding is valid + attention_mask = ops.ones((self.batch_size, self.seq_length), mindspore.int64) + attention_mask[0, :1] = 0 + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], 
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+        config = self.get_config(
+            gradient_checkpointing=gradient_checkpointing,
+        )
+
+        return (
+            config,
+            input_ids,
+            attention_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        )
+
+    def get_config(self, gradient_checkpointing=False):
+        return Mamba2Config(
+            head_dim=self.head_dim,
+            num_heads=self.num_heads,
+            n_groups=self.n_groups,
+            state_size=self.state_size,
+            conv_kernel=self.conv_kernel,
+            chunk_size=self.chunk_size,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            activation_function=self.hidden_act,
+            n_positions=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            use_cache=True,
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            gradient_checkpointing=gradient_checkpointing,
+            tie_word_embeddings=self.tie_word_embeddings,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        (
+            config,
+            input_ids,
+            _,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+        inputs_dict = {"input_ids": input_ids}
+        return config, inputs_dict
+
+    def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args):
+        model = Mamba2Model(config=config)
+        model.eval()
+
+        output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state
+
+        outputs = model(
+            input_ids[:, :-1],
+            attention_mask=attention_mask[:, :-1],
+            use_cache=True,
+            cache_position=ops.arange(0, config.conv_kernel),
+        )
+        output_one = outputs.last_hidden_state
+
+        # Using the state computed on the first inputs, we will get the same output
+        outputs = model(
+            input_ids[:, -1:],
+            attention_mask=attention_mask[:, -1:],
+            use_cache=True,
+            cache_params=outputs.cache_params,
+            cache_position=ops.arange(config.conv_kernel, config.conv_kernel + 1),
+        )
+        output_two = outputs.last_hidden_state
+
+        self.parent.assertTrue(
+            ops.allclose(ops.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3)
+        )
+
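# The caching check above captures how incremental decoding is meant to work with Mamba2Cache:
# prefill the prompt once, then feed only the newest token together with the returned state.
# A minimal greedy-step sketch, assuming the mindnlp Mamba2ForCausalLM forward mirrors the
# HF signature (`use_cache`, `cache_params`, `cache_position`) and that a tiny random-weight
# config is enough for illustration:
import mindspore
from mindnlp.core import ops, no_grad
from mindnlp.transformers import Mamba2Config, Mamba2ForCausalLM

config = Mamba2Config(num_heads=8, head_dim=8, n_groups=8, state_size=2, conv_kernel=4,
                      chunk_size=8, vocab_size=99, hidden_size=32, num_hidden_layers=2)
model = Mamba2ForCausalLM(config)
model.eval()

input_ids = ops.ones((1, 7), mindspore.int64)  # dummy prompt
with no_grad():
    # prefill: run the whole prompt once and keep the returned recurrent/conv state
    out = model(input_ids, use_cache=True, cache_position=ops.arange(0, config.conv_kernel))
    cache = out.cache_params
    next_token = out.logits[:, -1:].argmax(-1)
    # decode step: only the newest token is fed; the SSM state lives in `cache`
    step = model(next_token, use_cache=True, cache_params=cache,
                 cache_position=ops.arange(config.conv_kernel, config.conv_kernel + 1))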
+@require_mindspore
+class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_mindspore_available() else ()
+    all_generative_model_classes = (Mamba2ForCausalLM,) if is_mindspore_available() else ()
+    has_attentions = False  # Mamba does not support attentions
+    fx_compatible = False  # FIXME let's try to support this @molbap
+    test_missing_keys = False
+    test_model_parallel = False
+    test_pruning = False
+    test_head_masking = False  # Mamba does not have attention heads
+
+    pipeline_model_mapping = (
+        {"feature-extraction": Mamba2Model, "text-generation": Mamba2ForCausalLM} if is_mindspore_available() else {}
+    )
+
+    def setUp(self):
+        self.model_tester = Mamba2ModelTester(self)
+        self.config_tester = ConfigTester(
+            self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
+        )
+
+    @unittest.skip(reason="Skipped in mamba")
+    def test_mamba2_caching(self):
+        pass
+        # config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        # self.model_tester.create_and_check_mamba2_caching(*config_and_inputs)
+
+    def test_initialization(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=config)
+            for name, param in model.named_parameters():
+                if "D" in name:
+                    if param.requires_grad:
+                        # check that D is initialized to ones
+                        assert ops.allclose(param.data, ops.ones_like(param.data), rtol=1e-5, atol=1e-5)
+
+    @unittest.skip(reason="Mamba 2 weights are not tied")
+    def test_tied_weights_keys(self):
+        pass
+
+    @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
+    def test_multi_gpu_data_parallel_forward(self):
+        pass
+
+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            with no_grad():
+                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+            def recursive_check(tuple_object, dict_object):
+                if isinstance(tuple_object, Mamba2Cache):  # MODIFIED PART START
+                    recursive_check(tuple_object.conv_states, dict_object.conv_states)
+                    recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
+                elif isinstance(tuple_object, (List, Tuple)):  # MODIFIED PART END
+                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif isinstance(tuple_object, Dict):
+                    for tuple_iterable_value, dict_iterable_value in zip(
+                        tuple_object.values(), dict_object.values()
+                    ):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif tuple_object is None:
+                    return
+                else:
+                    self.assertTrue(
+                        ops.allclose(tuple_object, dict_object, atol=1e-5),
+                        msg=(
+                            "Tuple and dict output are not equal. Difference:"
+                            f" {ops.max(ops.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                            f" {ops.isnan(tuple_object).any()} and `inf`: {ops.isinf(tuple_object).any()}. Dict has"
+                            f" `nan`: {ops.isnan(dict_object).any()} and `inf`: {ops.isinf(dict_object).any()}."
+                        ),
+                    )
+
+            recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.eval()
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+
+@require_mindspore
+@slow
+@require_read_token
+class Mamba2IntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False, from_pt=True)
+        self.prompt = ("[INST]Write a hello world program in C++.",)
+
+    @require_read_token
+    @slow
+    @require_mindspore
+    def test_simple_generate(self):
+        """
+        Simple generate test to avoid regressions.
+        Note: state-spaces (cuda) implementation and pure torch implementation
+        have irreconcilable differences as of now, which will cause this test to fail
+        in an environment with state-spaces installed.
+ """ + tokenizer = self.tokenizer + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = Mamba2ForCausalLM.from_pretrained(self.model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True) + input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"] + + out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30) + output_sentence = tokenizer.decode(out[0]) + ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include \n\n""" + assert output_sentence == ground_truth_sentence + + @require_read_token + @slow + @require_mindspore + def test_batched_equivalence_with_cache(self): + """ + Verifies that batched generation matches individual generation. + Important because of the specific caching mechanism + statefulness of mamba model. + Depending on precision and devices, differences can be observed from generation to generation. + """ + tokenizer = self.tokenizer + prompt = [ + "[INST]Write C#.[/INST]", + "[INST]Write a hello world in C++.[/INST]", + "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", + ] + + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=mindspore.bfloat16, from_pt=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + # batched generation + tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest") + batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True) + batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True) + + # individual generation + + for index_gen, individual_prompt in enumerate(prompt): + inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest") + individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True) + individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] + assert individual_output[:100] == batched_output[index_gen][:100] + + @require_read_token + @slow + def test_batched_equivalence_without_cache(self): + """ + Verifies that batched generation matches individual generation without cache. + Important because of the specific caching mechanism + statefulness of mamba model. + Depending on precision and devices, differences can be observed from generation to generation. 
+ """ + tokenizer = self.tokenizer + prompt = [ + "[INST]Write C#.[/INST]", + "[INST]Write a hello world in C++.[/INST]", + "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", + ] + + model = Mamba2ForCausalLM.from_pretrained(self.model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + # batched generation + tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest") + batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True) + batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True) + + # individual generation + + for index_gen, individual_prompt in enumerate(prompt): + inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest") + individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True) + individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] + assert individual_output[:100] == batched_output[index_gen][:100] + + @slow + @require_mindspore + def test_mamba2_mixer_train_vs_eval_equivalence(self): + # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63 + # Credit to zhixuan-lin + + B, T, D = 4, 512, 768 + dtype = mindspore.bfloat16 + config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) + + mindspore.set_seed(42) + with mindspore.amp.autocast(dtype=dtype): + with no_grad(): + mixer = Mamba2Mixer(config, layer_idx=0) + hidden_states = ops.rand(size=(B, T, D), dtype=dtype) + + mixer.train() + out_train = mixer(hidden_states) + + mixer.eval() + out_eval = mixer(hidden_states) + + assert ops.allclose(out_train, out_eval, rtol=1e-3, atol=1e-3) \ No newline at end of file