
Mamba: SSM, Theory, and Implementation in Keras and TensorFlow | by Vedant Jumle



Understanding how SSMs and Mamba work, along with how to get started with implementing them in Keras and TensorFlow.

Source: AI-generated (SDXL)

Submitted on 1st December 2023 to arXiv, the paper titled "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" proposed an interesting approach to sequence modeling. The authors, Albert Gu and Tri Dao, introduced 'Mamba', which uses 'selective' state space models (SSM) to achieve results that compete with the performance of the now-ubiquitous Transformer model.

Transformers have seen a surge in popularity with the rise of Large Language Models (LLMs) like LLaMa-2, GPT-4, Claude, Gemini, etc., but they suffer from the problem of a limited context window. The issue lies at the Transformer's core: the multi-head attention mechanism.

The main issue with multi-head attention stems from the fact that for an input sequence of length n, the time and space complexity scale as O(n²). This limits the length of an LLM's context window, because to increase it by 10x, we need to scale the hardware requirement (most notably GPU VRAM) by 100x.

Mamba, on the other hand, scales as O(n), i.e., linearly!

Plot taken from the Mamba paper comparing FlashAttention and the Mamba approach (indicated by scan (ours) in the legend)[1]

This linear scaling is what has led researchers to speculate that Mamba might be the future of sequence modeling.

The core of the Mamba model comes from the concept of State Space Models. State Space Models, like Transformers and RNNs, process sequences of information, such as text, audio signals, video frames, DNA sequences, etc.

State Space Models come from the idea of describing a physical system as a set of inputs, outputs, and variables. These variables are A, B, C, and D. The process of an SSM involves calculating an internal state vector h(t), given an input x(t). Then, we take weighted sums of h(t) and x(t), where the weights are A, B, C, & D. In the simplest form (continuous time-invariant), the formulation looks like this:
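
h'(t) = A·h(t) + B·x(t)
y(t) = C·h(t) + D·x(t)

where h'(t) is the time derivative of the hidden state.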

source: Wikipedia[6]

h(t) is often called the 'hidden' or the 'latent' state; I will stick to calling it the 'hidden' state for better readability. It is important to note that A, B, C, and D are learned parameters in an SSM.

What are the variables?

The variables A, B, C & D are learned parameters, and they can be described as:

  • A: How much the previous hidden state (h) should be considered when calculating the new hidden state.
  • B: How much the input (x) should be considered when calculating the new hidden state.
  • C: How much the new hidden state should be considered when calculating the output (y).
  • D: How much the input (x) should be considered when calculating the output (y).

D comes at the end of the computation and does not affect how the hidden state is calculated. Hence, it is usually considered outside of the SSM and can be thought of as a skip connection.

Going from continuous spaces to discrete spaces

The above formulation applies to a system where the input and output belong to a continuous space. But in cases like language modeling, the input and output belong to discrete spaces (token values in a vocabulary); also, solving for h(t) analytically is challenging. We can bridge this gap by performing a zero-order hold.

In a zero-order hold, every time an input is received, the model holds its value until the next input arrives. This gives us a continuous input signal.

How zero-order hold works

The length of this 'hold' is determined by a new parameter called the step size, ∆. It can be thought of as the resolution of the input. Ideally, ∆ should be infinitesimal.

Mathematically, the zero-order hold discretization can be described as (using A_bar and B_bar for the discretized parameters, matching the notation in the code later):
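
A_bar = exp(∆A)
B_bar = (∆A)^(-1) · (exp(∆A) − I) · ∆B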

Finally, we can create a discrete SSM:
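
h_k = A_bar · h_(k−1) + B_bar · x_k
y_k = C · h_k + D · x_k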

Since D is used in a skip connection outside of the SSM, the output can be reduced to:
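
y_k = C · h_k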

The Dx(t) term is treated as a skip connection, hence it is applied outside of the SSM.

In SSMs, the hidden state is carried over to when the next input is received. This is similar to how Recurrent Neural Networks function.

Comparison of RNN and SSM

This recurrent form of the SSM can be unrolled, just like an RNN. But unlike RNNs, which are iterative and slow, SSMs can process the input sequence in parallel (just like transformers), and this makes training faster.

Unrolled form of the SSM

Note that 'D' is used in a skip connection, which is outside of the SSM.

The key insight into how SSMs make training fast is to use the variables A, B, and C in a pre-computed convolutional kernel. Maarten Grootendorst wrote a very good explanation of how this canonical 'convolutional' kernel is constructed. But here's a simple mathematical explanation.

Consider the output y. Assuming the hidden state starts at zero (i.e., there is no state before the first token), the first few hidden states and outputs expand to:
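
h_0 = B_bar · x_0
h_1 = A_bar · B_bar · x_0 + B_bar · x_1
y_2 = C · A_bar^2 · B_bar · x_0 + C · A_bar · B_bar · x_1 + C · B_bar · x_2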

Similarly, y_3 can be represented as:
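
y_3 = C · A_bar^3 · B_bar · x_0 + C · A_bar^2 · B_bar · x_1 + C · A_bar · B_bar · x_2 + C · B_bar · x_3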

Extrapolating the pattern, y_k can be represented as:
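
y_k = C · A_bar^k · B_bar · x_0 + C · A_bar^(k−1) · B_bar · x_1 + … + C · A_bar · B_bar · x_(k−1) + C · B_bar · x_k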

This formulation can be further reduced to:
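
y = x ∗ K_bar,   where   K_bar = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, …, C·A_bar^(L−1)·B_bar)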

The ∗ symbol represents a convolution operation, where the convolution kernel is K_bar (and L is the sequence length). Notice that K_bar does not depend on x, hence K_bar can be pre-computed into a convolutional kernel, which makes the process faster.

As good as the computational efficiency of SSMs sounds, they turn out to be rather underwhelming in metrics like accuracy when compared to Transformers.

The core issue lies with the variables ∆, A, B, & C. It turns out that since we apply the same matrices to every input, they cannot really process the context of the sequence.

SSMs are rigid in the way they process information[4]

So what’s so particular about Mamba? In mamba, we use a course of referred to as ‘selective’ SSM, the place the variables, ∆, B, & C, are computed based mostly on the enter. 🤔. We do that by passing the present enter by way of Linear layers, and take the output to be the ∆, B, & C.

But this makes ∆, B, & C input-dependent, which means they can no longer be pre-computed 😢, so the fast convolution trick isn't going to work here. Instead, the authors discuss a method based on a parallel associative scan.

Parallel Associative Scan

Parallel associative scan is a powerful technique used in parallel computing to perform a prefix sum operation, which is a cumulative operation over a sequence of numbers. The operation is "associative", meaning the way the numbers are grouped does not change the result.

Parallel prefix sum is an example of associative scanning. (source: Nvidia)[7]

In the context of the Mamba model, defining an associative operator over the per-token state updates lets us evaluate the whole recurrence with a parallel associative scan. This allows the entire sequence to be processed in parallel, resulting in logarithmic time complexity in the number of sub-intervals (given enough parallel hardware).
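
To make the idea concrete, here is a small NumPy sketch (not the actual Mamba kernel) showing that a recurrence of the form h_t = a_t·h_(t−1) + b_t can be expressed through an associative operator, which is what allows a parallel scan to evaluate it; the names combine, a, and b are purely illustrative:

import numpy as np

# The recurrence h_t = a_t * h_{t-1} + b_t corresponds to composing affine maps,
# with the associative operator (a1, b1) ∘ (a2, b2) = (a1 * a2, a2 * b1 + b2).
# Because the operator is associative, its prefix "sums" can be computed by a
# parallel (tree-structured) scan instead of a sequential loop.

def combine(left, right):
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

a = np.array([0.9, 0.8, 0.7, 0.6])   # per-step state transitions (like A_bar)
b = np.array([1.0, 2.0, 3.0, 4.0])   # per-step inputs (like B_bar * x)

# sequential reference implementation of the recurrence
h, ref = 0.0, []
for t in range(len(a)):
    h = a[t] * h + b[t]
    ref.append(h)

# inclusive scan with the associative operator (written sequentially here;
# a parallel implementation would combine these pairs in a tree across threads)
scan = [(a[0], b[0])]
for t in range(1, len(a)):
    scan.append(combine(scan[-1], (a[t], b[t])))

print(ref)                    # hidden states from the loop
print([s[1] for s in scan])   # identical hidden states recovered from the scan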

Hardware-aware algorithm

Along with the associative scan, the authors also propose a hardware-aware algorithm, where they exploit the characteristics of Nvidia GPUs related to the speed difference between HBM and SRAM. They argue that the computation of SSM states can be sped up by:

  • keeping the hidden state and A in the faster but smaller-capacity SRAM,
  • while computing ∆, B, & C in the slower but larger-capacity HBM,
  • then transferring ∆, B, & C to SRAM and computing the new hidden state within SRAM,
  • and finally writing the results back to HBM.
Illustration taken from the Mamba paper, showing how the hardware-aware algorithm works[1]

In the implementation section, I will not be discussing how to work with the hardware-aware algorithm; rather, I will only be using the parallel associative scan.

With all of this in mind, let's explore and implement the Mamba architecture using Keras and TensorFlow.

After reading the paper and analyzing the code, the Mamba architecture can be broken down into a few key components, which are connected as:

Breakdown of a Mamba block

The Mamba architecture consists of several stacked layers of 'Mamba blocks', each of which, judging from the above illustration, consists of quite a few components. Another important thing to note is that the authors add the output of the selective SSM to the original input and then apply a normalization layer to it. This normalization can be either layer normalization or RMS normalization.

Let's begin coding Mamba. We will be using the following dependencies:

tensorflow[and-cuda]==2.15.0.post1  # if you want to use the GPU, or
tensorflow==2.15.0.post1            # if you only want to use the CPU
transformers==4.36.2                # for the bert tokenizer
einops==0.7.0                       # helpful for faster matrix manipulation
datasets==2.16.1                    # to load datasets
# all other modules (like numpy) will be auto-installed

Imports:

import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers, Model

from dataclasses import dataclass
from einops import rearrange, repeat
from typing import Union

from transformers import AutoTokenizer

import datasets
import math
import numpy as np

To make processing the model arguments easier, let's create a simple ModelArgs dataclass as a config class. This allows us to just pass the dataclass variable in the arguments when we are initializing the model.

@dataclass
class ModelArgs:
    model_input_dims: int = 64
    model_states: int = 64
    projection_expand_factor: int = 2
    conv_kernel_size: int = 4
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    conv_use_bias: bool = True
    dense_use_bias: bool = False
    layer_id: int = -1
    seq_length: int = 128
    num_layers: int = 5
    dropout_rate: float = 0.2
    use_lm_head: bool = False
    num_classes: int = None
    vocab_size: int = None
    final_activation = None
    loss: Union[str, keras.losses.Loss] = None
    optimizer: Union[str, keras.optimizers.Optimizer] = keras.optimizers.AdamW()
    metrics = ['accuracy']

    def __post_init__(self):
        self.model_internal_dim: int = int(self.projection_expand_factor * self.model_input_dims)

        self.delta_t_rank = math.ceil(self.model_input_dims / 16)
        if self.layer_id == -1:
            self.layer_id = np.round(np.random.randint(0, 1000), 4)

        if self.vocab_size is None:
            raise ValueError("vocab size cannot be None")

        if self.use_lm_head:
            self.num_classes = self.vocab_size
        else:
            if self.num_classes is None:
                raise ValueError(f'num classes cannot be {self.num_classes}')

        if self.num_classes == 1:
            self.final_activation = 'sigmoid'
        else:
            self.final_activation = 'softmax'

        if self.loss is None:
            raise ValueError(f"loss cannot be {self.loss}")

Load the bert-base-uncased tokenizer:

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

Before we implement our Mamba and SSM classes, we need to implement the parallel associative scan; the code looks like this:

def selective_scan(u, delta, A, B, C, D):
    # first step of A_bar = exp(ΔA), i.e., ΔA
    dA = tf.einsum('bld,dn->bldn', delta, A)
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # flip along the sequence axis

    # cumulative sum along all the input tokens (a parallel prefix sum);
    # this computes dA for all the input tokens in parallel
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)

    # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.exp(dA_cumsum)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # flip back along the sequence axis

    x = dB_u * dA_cumsum
    # 1e-12 to avoid division by 0
    x = tf.math.cumsum(x, axis=1) / (dA_cumsum + 1e-12)

    y = tf.einsum('bldn,bln->bld', x, C)

    return y + u * D

With this, we can implement the MambaBlock:

class MambaBlock(layers.Layer):
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        args = modelargs
        self.layer_id = modelargs.layer_id

        self.in_projection = layers.Dense(
            args.model_internal_dim * 2,
            input_shape=(args.model_input_dims,), use_bias=False)

        self.conv1d = layers.Conv1D(
            filters=args.model_internal_dim,
            use_bias=args.conv_use_bias,
            kernel_size=args.conv_kernel_size,
            groups=args.model_internal_dim,
            data_format='channels_first',
            padding='causal'
        )

        # this layer takes in the current token 'x'
        # and outputs the input-specific Δ, B, C (as per S6)
        self.x_projection = layers.Dense(args.delta_t_rank + args.model_states * 2, use_bias=False)

        # this layer projects Δ from delta_t_rank to the mamba internal
        # dimension
        self.delta_t_projection = layers.Dense(args.model_internal_dim,
                                               input_shape=(args.delta_t_rank,), use_bias=True)

        self.A = repeat(
            tf.range(1, args.model_states + 1, dtype=tf.float32),
            'n -> d n', d=args.model_internal_dim)

        self.A_log = tf.Variable(
            tf.math.log(self.A),
            trainable=True, dtype=tf.float32,
            name=f"SSM_A_log_{args.layer_id}")

        self.D = tf.Variable(
            np.ones(args.model_internal_dim),
            trainable=True, dtype=tf.float32,
            name=f"SSM_D_{args.layer_id}")

        self.out_projection = layers.Dense(
            args.model_input_dims,
            input_shape=(args.model_internal_dim,),
            use_bias=args.dense_use_bias)

    def call(self, x):
        """Mamba block forward pass. This looks the same as Figure 3 in Section 3.4 of the Mamba paper.
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """

        (batch_size, seq_len, dimension) = x.shape

        x_and_res = self.in_projection(x)  # shape = (batch, seq_len, 2 * model_internal_dim)
        (x, res) = tf.split(x_and_res,
                            [self.args.model_internal_dim,
                             self.args.model_internal_dim], axis=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = tf.nn.swish(x)
        y = self.ssm(x)
        y = y * tf.nn.swish(res)
        return self.out_projection(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 of the Mamba paper
            - run_SSM(A, B, C, u) in The Annotated S4
        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆, A, B, C, D, the state space parameters.
        #     A and D are input-independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time-invariant S4,
        #     and is why Mamba is called **selective** state spaces)

        A = -tf.exp(tf.cast(self.A_log, tf.float32))  # shape -> (d_in, n)
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x)  # shape -> (batch, seq_len, delta_t_rank + 2*n)

        (delta, B, C) = tf.split(
            x_dbl,
            num_or_size_splits=[self.args.delta_t_rank, n, n],
            axis=-1)  # delta.shape -> (batch, seq_len, delta_t_rank); B, C shape -> (batch, seq_len, n)

        delta = tf.nn.softplus(self.delta_t_projection(delta))  # shape -> (batch, seq_len, model_internal_dim)

        return selective_scan(x, delta, A, B, C, D)

Finally, a residual block to implement the external skip connection.

class ResidualBlock(layers.Layer):
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        self.mixer = MambaBlock(modelargs)
        self.norm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """
        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

        Note: the official repo chains residual blocks that look like
            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
        where the first Add is a no-op. This is purely for performance reasons as it
        allows them to fuse the Add->Norm.

        We instead implement our blocks as the more familiar, simpler, and numerically equivalent
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...

        """
        return self.mixer(self.norm(x)) + x

With this, we can initialize our model. In this example, I will be demonstrating how to use the Mamba block to create a simple classification model, but it can easily be modified to become a language model. Let's load the IMDB reviews dataset for a simple sentiment classifier.

from datasets import load_dataset
from tqdm import tqdm

dataset = load_dataset("ajaykarthick/imdb-movie-reviews")

First, we create a function that takes the model args and returns a model.

def init_model(args: ModelArgs):
    input_layer = layers.Input(shape=(args.seq_length,), name='input_ids')
    x = layers.Embedding(
        args.vocab_size,
        args.model_input_dims,
        input_length=args.seq_length)(input_layer)

    for i in range(args.num_layers):
        x = ResidualBlock(args, name=f"Residual_{i}")(x)
        x = layers.Dropout(args.dropout_rate)(x)  # for regularization

    x = layers.LayerNormalization(epsilon=1e-5)(x)  # normalization layer

    # use flatten only if we are not using the model as an LM
    if not args.use_lm_head:
        x = layers.Flatten()(x)
    x = layers.Dense(1024, activation=tf.nn.gelu)(x)
    output_layer = layers.Dense(
        args.num_classes,
        activation=args.final_activation)(x)

    model = Model(
        inputs=input_layer,
        outputs=output_layer, name='Mamba_ka_Mamba')
    model.compile(
        loss=args.loss,
        optimizer=args.optimizer,
        metrics=args.metrics
    )

    return model

Now we can initialize our model and summarize it:

args = ModelArgs(
    model_input_dims=128,
    model_states=32,
    num_layers=12,
    dropout_rate=0.2,
    vocab_size=vocab_size,
    num_classes=1,
    loss='binary_crossentropy',
)
model = init_model(args)
model.summary()
Model: "Mamba_ka_Mamba"
_________________________________________________________________
 Layer (type)                 Output Shape              Param #
=================================================================
 input_ids (InputLayer)       [(None, 128)]             0

 embedding_2 (Embedding)      (None, 128, 128)          3906816

 Residual_0 (ResidualBlock)   (None, 128, 128)          129024

 dropout_24 (Dropout)         (None, 128, 128)          0

 Residual_1 (ResidualBlock)   (None, 128, 128)          129024

 dropout_25 (Dropout)         (None, 128, 128)          0

 ... (I've shortened this to make it more readable)

 dropout_35 (Dropout)         (None, 128, 128)          0

 layer_normalization_38 (LayerNormalization)  (None, 128, 128)  256

 flatten_2 (Flatten)          (None, 16384)             0

 dense_148 (Dense)            (None, 1024)              16778240

 dense_149 (Dense)            (None, 1)                 1025

=================================================================
Total params: 22234625 (84.82 MB)
Trainable params: 22234625 (84.82 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

For easier processing, let's pre-tokenize our data into numpy arrays, then convert them into tf.data.Dataset objects:

train_labels, test_labels = [], []
train_ids = np.zeros((len(dataset['train']), args.seq_length))
test_ids = np.zeros((len(dataset['test']), args.seq_length))

for i, item in enumerate(tqdm(dataset['train'])):
    text = item['review']
    train_ids[i, :] = tokenizer.encode_plus(
        text,
        max_length=args.seq_length,
        padding='max_length',
        return_tensors='np')['input_ids'][0][:args.seq_length]

    train_labels.append(item['label'])

for i, item in enumerate(tqdm(dataset['test'])):
    text = item['review']
    test_ids[i, :] = tokenizer.encode_plus(
        text,
        max_length=args.seq_length,
        padding='max_length',
        return_tensors='np')['input_ids'][0][:args.seq_length]

    test_labels.append(item['label'])

del dataset  # delete the original dataset to save memory

BATCH_SIZE = 32
train_dataset = tf.data.Dataset.from_tensor_slices((train_ids, train_labels)).batch(BATCH_SIZE).shuffle(1000)
test_dataset = tf.data.Dataset.from_tensor_slices((test_ids, test_labels)).batch(BATCH_SIZE).shuffle(1000)

Now the model can be trained:

history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)

You can play around with the inference function:

def infer(text: str, model: Model, tokenizer):
    tokens = tokenizer.encode(
        text,
        max_length=args.seq_length,
        padding='max_length', return_tensors='np')
    output = model(tokens)[0, 0]
    return output

This model can be converted into a language model, and algorithms like beam search, top-k sampling, greedy sampling, etc. can be used to generate text.
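
For instance, a minimal greedy-decoding loop could look roughly like the sketch below. It assumes the model was built with use_lm_head=True (so the output is a distribution over vocab_size at every position) and reuses args, np, and the tokenizer from above; treat it as an illustration rather than a tested generation routine.

def greedy_generate(prompt: str, model, tokenizer, max_new_tokens: int = 20):
    # encode the prompt, then repeatedly feed the last seq_length tokens back in
    ids = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        context = ids[-args.seq_length:]
        padded = context + [tokenizer.pad_token_id] * (args.seq_length - len(context))
        logits = model(np.array([padded]))                      # (1, seq_length, vocab_size) when use_lm_head=True
        next_id = int(np.argmax(logits[0, len(context) - 1]))   # greedy: pick the most likely next token
        ids.append(next_id)
    return tokenizer.decode(ids)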

This code can be found on my GitHub.

A lot of the code is inspired by Mamba's official implementation[2] and another PyTorch implementation called 'mamba-tiny'[3].

Thanks for reading.


