Mamba's S6 by Hand ✍️
[1D, 4 tokens, 2 hidden states]
Just one more thing to do before my Computer Vision course ends this semester: grading! But before grading consumes the rest of my day, let me share my hands-on exercise on Mamba, as promised (Pascal Biese, Steve Solun).
The Mamba paper was released on arXiv on Dec. 1. It has since generated quite a buzz, including posts by Agnieszka Mikołajczyk and Evandro Barros. It is touted as the first linear-time model that beats the Transformer (whose attention mechanism is quadratic in sequence length).
At its core, Mamba is based on the new S6 model, which stands for Structured State-Space Sequence modeling using Selective Scan.
I want to thank the first author of the Mamba paper, Prof. Albert Gu at CMU, for verifying the accuracy of my understanding of how Mamba's S6 model works. 🙏
-- 𝗪𝗮𝗹𝗸𝘁𝗵𝗿𝗼𝘂𝗴𝗵 --
1. Left: All four tokens in the input sequence are processed by a linear layer to calculate a set of weights.
2. Right: These weights are used to drive an RNN-like network.
3. The first input [3] and hidden states [0, 0] are linearly combined using weights B=[-1; 2] and A=[1, 0; 0, -1] to calculate new hidden states -> [-3, 6]. Note that NO non-linear activation function is involved.
4. Hidden states [-3, 6] are linearly combined using weights C = [-2, -3] to obtain the first output [-12].
5. Repeat steps 3 and 4 for the remaining tokens, each time using a different set of weights A, B, C (as shown in the sketch below).
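To make the recurrence concrete, here is a minimal NumPy sketch of the walkthrough above. The step-1 weights are the ones worked out by hand; reusing them for the later steps is an illustrative shortcut on my part, since in S6 every step gets its own predicted weights:

```python
import numpy as np

# S6 recurrence from the walkthrough (no activation function anywhere):
#   h_t = A_t @ h_{t-1} + B_t * x_t      (linear state update)
#   y_t = C_t @ h_t                      (linear readout)

x = [3.0, 1.0, -2.0, 4.0]              # 4 input tokens (1D each)
A = [np.diag([1.0, -1.0])] * 4         # per-step 2x2 state matrices
B = [np.array([-1.0, 2.0])] * 4        # per-step input weights
C = [np.array([-2.0, -3.0])] * 4       # per-step output weights

h = np.zeros(2)                        # 2 hidden states, initialized to [0, 0]
for t in range(4):
    h = A[t] @ h + B[t] * x[t]         # step 3 of the walkthrough
    y = C[t] @ h                       # step 4 of the walkthrough
    print(f"step {t + 1}: h = {h}, y = {y}")
# step 1 prints h = [-3.  6.] and y = -12.0, matching the hand calculation
```

Because every operation in the loop is linear, the whole scan can be unrolled into one big linear map from inputs to outputs, which is what makes fast parallel implementations possible.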
-- 𝗖𝗼𝗺𝗽𝗮𝗿𝗶𝘀𝗼𝗻𝘀 𝘁𝗼 𝗮𝗻 𝗥𝗡𝗡 --
(Prerequisite: my previous post about RNNs)
💡 𝗣𝗮𝗿𝗮𝗺𝗲𝘁𝗲𝗿𝘀: Rather than reusing the same set of weight parameters at every step, each step in the sequence uses its own set. These weights are "predicted" from the input tokens by a linear layer (sketched below), rather than being trained directly as in a conventional RNN.
💡 𝗟𝗶𝗻𝗲𝗮𝗿𝗶𝘁𝘆: Unlike a conventional RNN, S6 does not use any non-linear activation function.
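As a sketch of what "predicted" means, continuing the 1D-token, 2-hidden-state setup above: a linear layer maps each token to that step's weights. The projection matrices here are random placeholders I made up for illustration, not the paper's trained values:

```python
import numpy as np

rng = np.random.default_rng(0)

d_state = 2                            # number of hidden states
x = np.array([3.0, 1.0, -2.0, 4.0])    # 4 input tokens (1D each)

# Hypothetical projection matrices (random placeholders).
W_B = rng.standard_normal((d_state, 1))
W_C = rng.standard_normal((d_state, 1))

# Each token is projected independently, so every step t gets
# its own B_t and C_t vector.
B = (W_B @ x[None, :]).T               # shape (4, 2): one B per step
C = (W_C @ x[None, :]).T               # shape (4, 2): one C per step
print(B.shape, C.shape)                # -> (4, 2) (4, 2)
```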
-- 𝗦𝟲'𝘀 𝗦𝗶𝘅 𝗦'𝘀 --
1️⃣ 𝗦𝗲𝗹𝗲𝗰𝘁𝗶𝘃𝗲: Weights in each step are selectively set by a linear layer.
2️⃣ 𝗦𝗰𝗮𝗻: Because each step's weights are different, it needs to scan through the input sequence to calculate each output token.
3️⃣ 𝗦𝘁𝗿𝘂𝗰𝘁𝘂𝗿𝗲𝗱: The "A" matrices (the square matrices that combine the hidden states) are assumed to have a structure that simplifies computation. In the paper, that structure is diagonal (see the sketch after this list).
4️⃣ 5️⃣ 𝗦𝘁𝗮𝘁𝗲-𝗦𝗽𝗮𝗰𝗲: The model maintains "hidden states" that evolve from step to step.
6️⃣ 𝗦𝗲𝗾𝘂𝗲𝗻𝗰𝗲: It is a sequence-to-sequence model.
Note that S6 extends the earlier S4 model; the first two S's (Selective and Scan) are new.
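Why the diagonal structure matters (point 3️⃣): with a diagonal A, the state update collapses from a matrix-vector product, O(N²) for N hidden states, to an elementwise multiply, O(N). A small sketch using the values from the walkthrough:

```python
import numpy as np

a_diag = np.array([1.0, -1.0])   # diagonal of the 2x2 A matrix from step 1
h = np.array([-3.0, 6.0])        # hidden states after step 1

# Full matrix-vector product: O(N^2) for N hidden states.
h_full = np.diag(a_diag) @ h
# Diagonal shortcut: an elementwise multiply, O(N).
h_fast = a_diag * h

assert np.allclose(h_full, h_fast)
```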
Thanks for reading! Feel free to leave your questions in the comments!
Reference:
[1] Albert Gu, Tri Dao. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752, 2023.