Switch Transformer by Hand ✍️
Google recently released Gemini 1.5, which adds Sparse Mixture of Experts to Gemini 1.0's architecture.
A simplified form of Sparse Mixture of Experts is at the heart of the Switch Transformer model described in a 2022 article in the Journal of Machine Learning Research:
"Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" by William (Liam) Fedus, Barret Zoph, Noam Shazeer.
How does a Switch Transformer work?
-- 𝗪𝗮𝗹𝗸𝘁𝗵𝗿𝗼𝘂𝗴𝗵 --
[1] Given
↳ Input features (X1-X5) from the previous block
🟨 🇦🇹🇹🇪🇳🇹🇮🇴🇳
[2] Attention Matrix
↳ Feed all 5 features to a query-key attention module (QK) to obtain an attention weight matrix (A).
[3] Pooling
↳ Multiply the input features with the attention weight matrix to obtain attention weighted features (Z1-Z5).
↳ The effect is to combine features across positions (horizontally)
[4] Visualize Pooling
↳ Z4 := X4 + X5 because the 4th column in the attention weight matrix A is [0,0,0,1,1]
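Steps [2]-[4] can be sketched in a few lines of plain Python. This is only an illustration: the values in X are made up, and in A only the 4th column [0,0,0,1,1] comes from the walkthrough (the other columns are assumed). Features are stored as columns, so pooling is Z = X · A.

```python
# Hypothetical inputs: 3 feature dimensions (rows) x 5 positions (columns X1..X5).
X = [
    [1, 0, 2, 1, 3],
    [0, 1, 1, 2, 0],
    [2, 1, 0, 0, 1],
]

# Attention weight matrix A (5 x 5). Only the 4th column, [0,0,0,1,1],
# is taken from the walkthrough; every other column is an assumption.
A = [
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 1, 1],
]


def matmul(P, Q):
    """Multiply matrix P (n x k) by matrix Q (k x m), both given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]


def column(M, j):
    """Return column j (0-based) of matrix M as a list."""
    return [row[j] for row in M]


# Pooling: each Z column is a weighted combination of X columns.
Z = matmul(X, A)

z4 = column(Z, 3)
x4_plus_x5 = [a + b for a, b in zip(column(X, 3), column(X, 4))]
print(z4, x4_plus_x5)  # both are [4, 2, 1]
```

Because the 4th column of A has ones in rows 4 and 5, Z4 picks up exactly X4 + X5, which is the "horizontal" mixing across positions.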
🟩 🇸🇼🇮🇹🇨🇭
[5] Gate Values
↳ Multiply the attention-weighted features (Zs) by the switch matrix
↳ For each weighted feature Zi, each gate value indicates how well an expert (A,B,C) can probably handle the feature
[6] Top Expert
↳ For each feature, find the row (Expert ID) with the highest gate value
↳ "Sparse" refers to the selection of only the top expert, not all the experts.
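Steps [5]-[6] amount to one matrix multiplication followed by a column-wise argmax. A rough sketch, with assumptions stated up front: the C row [0, 1, 1] and Z5 = [5, 3, 1] are taken from the errata in this post, while the A and B rows of the switch matrix and the features Z1-Z4 are invented for illustration.

```python
EXPERTS = ["A", "B", "C"]

# Switch matrix S: one row per expert, one column per feature dimension.
# Row C = [0, 1, 1] follows the errata; rows A and B are hypothetical.
S = [
    [1, -1, 0],   # Expert A (assumed)
    [-1, 1, 0],   # Expert B (assumed)
    [0, 1, 1],    # Expert C (from the errata)
]

# Attention-weighted features as columns Z1..Z5.
# Only Z5 = [5, 3, 1] follows the errata; Z1..Z4 are made up.
Z = [
    [3, 0, 1, 2, 5],
    [0, 2, 2, 2, 3],
    [1, -1, 2, 1, 1],
]


def matmul(P, Q):
    """Multiply matrix P (n x k) by matrix Q (k x m), both given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]


# Gate values: rows are experts, columns are features Z1..Z5.
gates = matmul(S, Z)

# Top expert: for each column, pick the row with the highest gate value.
top_expert = []
for gate_column in zip(*gates):
    best_row = max(range(len(gate_column)), key=lambda i: gate_column[i])
    top_expert.append(EXPERTS[best_row])

print(top_expert)  # ['A', 'B', 'C', 'C', 'C']
```

Only one expert per feature is kept, which is the "sparse" part: the other two gate values in each column are ignored when routing.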
🟦 🇫🇫🇳
[7] Routing
↳ Each Z is routed to the best expert
↳ Each expert has a fixed capacity of 2.
↳ Note that Z5, which is supposed to go to Expert C, exceeds the capacity.
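The routing rule in [7] can be sketched as a simple first-come, first-served loop. The list of top experts below is an assumption chosen so that Z5 overflows Expert C, as in the walkthrough; the capacity of 2 comes from the text.

```python
CAPACITY = 2  # fixed capacity per expert, as in the walkthrough

# Hypothetical top expert for Z1..Z5 (the result of the gate step).
top_expert = ["A", "B", "C", "C", "C"]

assigned = {"A": [], "B": [], "C": []}  # features each expert will process
overflow = []                           # features that found their expert full

for position, expert in enumerate(top_expert, start=1):
    feature = f"Z{position}"
    if len(assigned[expert]) < CAPACITY:
        assigned[expert].append(feature)
    else:
        # The expert is already at capacity, so this feature is not processed.
        overflow.append(feature)

print(assigned)  # {'A': ['Z1'], 'B': ['Z2'], 'C': ['Z3', 'Z4']}
print(overflow)  # ['Z5']
```

Features are handled in position order here, so the later feature is the one that gets dropped. That ordering is an assumption of this sketch.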
[8] Expert A: Linear Layer
↳ Apply the linear layer (multiply Z1 by the weights and add the biases)
↳ The effect is to combine features across feature dimensions (vertically).
[9] Expert A: Aggregate
↳ Send the resulting combined feature to the corresponding output column.
↳ The paper has an extra scaling step, omitted here for simplicity.
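Here is a small sketch of [8]-[9] for a single feature, plus the scaling step that the walkthrough leaves out. The weights, biases, Z1, and gate values are all hypothetical. For the scaling, this sketch assumes the router probability is a softmax over the gate values and that the expert output is multiplied by the probability of the chosen expert.

```python
import math

# Hypothetical Expert A parameters (3 outputs x 3 inputs) and input feature Z1.
W_A = [
    [1, 0, 1],
    [0, 1, 0],
    [1, 1, 0],
]
b_A = [0, 1, -1]
z1 = [3, 0, 1]


def linear(W, b, z):
    """Linear layer: each output mixes all feature dimensions of z, plus a bias."""
    return [sum(w * x for w, x in zip(row, z)) + bias for row, bias in zip(W, b)]


# [8] Expert A combines Z1 across feature dimensions ("vertically").
y1 = linear(W_A, b_A, z1)
print(y1)  # [4, 1, 2]

# Omitted scaling step (assumed form): scale by the router probability of Expert A.
gate_values_z1 = [3, -3, 1]  # hypothetical gate values for experts A, B, C
exps = [math.exp(g) for g in gate_values_z1]
p_A = exps[0] / sum(exps)
y1_scaled = [p_A * value for value in y1]
```

In the by-hand version, y1 goes straight into output column 1. With scaling, y1_scaled would go there instead.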
[10-11] Expert B
↳ Repeat [8] and [9]
[12-13] Expert C
↳ Repeat [8] and [9]
↳ However, since Z5 exceeded Expert C's capacity, it was simply passed through as is to the next block.
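Putting [8]-[13] together, including the pass-through for the overflowing feature, might look like the sketch below. Everything numeric is an assumption except Z5 = [5, 3, 1] (from the errata): the expert weights, the other features, and the routing table are invented for illustration.

```python
def linear(W, b, z):
    """Linear layer: multiply z by the weights W, then add the biases b."""
    return [sum(w * x for w, x in zip(row, z)) + bias for row, bias in zip(W, b)]


# Features after attention. Only Z5 follows the errata; the rest are made up.
Z = {
    "Z1": [3, 0, 1],
    "Z2": [0, 2, -1],
    "Z3": [1, 2, 2],
    "Z4": [2, 2, 1],
    "Z5": [5, 3, 1],
}

# Hypothetical routing result with capacity 2: Z5 did not fit into Expert C.
assigned = {"A": ["Z1"], "B": ["Z2"], "C": ["Z3", "Z4"]}

# Hypothetical (weights, biases) for each expert.
experts = {
    "A": ([[1, 0, 1], [0, 1, 0], [1, 1, 0]], [0, 1, -1]),
    "B": ([[0, 1, 0], [1, 0, 0], [0, 0, 1]], [0, 0, 0]),
    "C": ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [1, 1, 1]),
}

output = {}

# Each expert processes only the features routed to it.
for name, features in assigned.items():
    W, b = experts[name]
    for feature in features:
        output[feature] = linear(W, b, Z[feature])

# Any feature that no expert processed is passed through unchanged.
for feature, z in Z.items():
    output.setdefault(feature, z)

print(output["Z5"])  # [5, 3, 1], identical to the input Z5
```

Every output column is filled either by an expert or by the untouched input, so the next block still receives five features.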
-- 𝗔𝗰𝗸𝗻𝗼𝘄𝗹𝗲𝗱𝗴𝗺𝗲𝗻𝘁 --
Lee Gao gave me several useful insights based on publicly available knowledge, without giving away any company secret. 😉
❌ ERRATA
Nadia Abdelrahman
The C row in Switch matrix should be
[0, 1, 1] not [-1, 1, 1]
Tin Tran Trung
FFN: Column 5 should be Z5 = [5,3,1], as described in the walkthrough. But in the animation, I wrote down the gate values [2,3,4], which is incorrect.


