Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture
Daniel Y. Fu, Simran Arora, Jessica Grogan, Isys Johnson, Sabri Eyuboglu, Armin W. Thomas, Benjamin Spector, Michael Poli, Atri Rudra, Christopher Ré
A number of great papers came out of Christopher Ré’s lab over the past few years, bringing ideas from database design and classical signal processing to neural sequence modelling. The first line of work exploits the GPU memory hierarchy to derive a hardware-aware implementation of the Transformer attention mechanism (FlashAttention). The second adapts state-space models of continuous signals to discrete language modelling (S4, H3). In both cases, these contributions enabled training models with long-range context on reduced hardware resources.
In this paper, the authors build on these ideas to tackle the quadratic runtime scaling (O(N²) in the sequence length N) of attention-based architectures. As prior work demonstrated, architectures based on long convolutions are powerful and promising replacements for attention modules. They have much lower asymptotic runtime (O(N log N) via an FFT implementation) but suffer from poor GPU utilisation, reaching only a small fraction of the hardware's peak throughput. Building on the previously introduced Monarch matrices, an expressive class of structured matrices formed from products of block-diagonal matrices and permutations, the authors propose the Monarch Mixer (M2) architecture. Because multiplying by a Monarch matrix reduces to batched dense matrix multiplications (GEMMs), M2 achieves sub-quadratic runtime together with much higher GPU utilisation, which allows training on longer sequences.
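To make the "block-diagonal plus permutation" structure concrete, here is a minimal NumPy sketch of multiplying by a Monarch matrix M = P L P R. It assumes the square case N = b² with b blocks of size b × b. P is the "reshape to b × b, transpose, flatten" permutation, which is its own inverse. The function names are hypothetical, and M2's actual parameterisation (block counts, factor ordering, which factors are learned) may differ. The fast path is checked against the materialised N × N matrix.

```python
import numpy as np

# Hypothetical helpers: a sketch of a Monarch matrix M = P L P R acting on
# vectors of length N = b*b, where L and R are block-diagonal (b blocks of b x b)
# and P is the "reshape to b x b, transpose, flatten" permutation.

def block_diag_apply(blocks, x):
    """Multiply x (length b*b) by the block-diagonal matrix whose i-th block is blocks[i]."""
    b = blocks.shape[0]
    xr = x.reshape(b, b)  # row i holds the i-th chunk of x
    return np.einsum("ijk,ik->ij", blocks, xr).reshape(-1)

def permute(x, b):
    """Apply P: view x as b x b, transpose, flatten (P is its own inverse)."""
    return x.reshape(b, b).T.reshape(-1)

def monarch_apply(L, R, x):
    """Compute M @ x = P L P R x with two batched matmuls: O(N**1.5) work."""
    b = L.shape[0]
    return permute(block_diag_apply(L, permute(block_diag_apply(R, x), b)), b)

def monarch_dense(L, R):
    """Materialise M as an N x N matrix, only to check the fast path (O(N**2))."""
    b = L.shape[0]
    n = b * b

    def dense(blocks):
        D = np.zeros((n, n))
        for i in range(b):
            D[i * b:(i + 1) * b, i * b:(i + 1) * b] = blocks[i]
        return D

    P = np.eye(n)[np.arange(n).reshape(b, b).T.reshape(-1)]  # P @ x == permute(x, b)
    return P @ dense(L) @ P @ dense(R)

rng = np.random.default_rng(0)
b = 4  # N = 16
L = rng.standard_normal((b, b, b))
R = rng.standard_normal((b, b, b))
x = rng.standard_normal(b * b)
print(np.allclose(monarch_apply(L, R, x), monarch_dense(L, R) @ x))  # True
```

Each `einsum` performs b³ = N^1.5 multiply-adds, compared with N² for a dense matrix. With a batch of inputs, the same contraction becomes a batched GEMM over the b blocks. That is the kind of operation GPU tensor cores execute efficiently, and it is the intuition behind M2's higher utilisation. The paper's actual kernels are not reproduced here.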
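In the particular case of autoregressive sequence prediction, enforcing a causal relationship between input and output tokens is essential: each output may depend only on the current and earlier inputs. The FFT does not provide this on its own. It computes a circular convolution, in which later tokens wrap around into earlier outputs unless the input is zero-padded. To obtain causal Monarch-based mixing, the authors derive a novel interpretation of Monarch matrix multiplication as multivariate polynomial evaluation and interpolation, which I found particularly surprising and interesting.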
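The polynomial view that the authors generalise already appears in the univariate FFT case. A length-n FFT evaluates a polynomial at the n-th roots of unity, and the inverse FFT interpolates coefficients from those values. Convolution is polynomial multiplication, and the causal outputs are the first N coefficients of the product. The sketch below illustrates this classical argument only and is not M2's multivariate construction; `fft_conv` is a hypothetical helper name. It shows that zero-padding to 2N points yields a causal convolution, while N points yield a circular one that leaks future tokens.

```python
import numpy as np

def fft_conv(u, k, pad=True):
    """Convolve a length-N sequence u with a length-N filter k via the FFT.

    Polynomial view: treat u and k as coefficient lists of u(z) and k(z).
    The length-n FFT evaluates each polynomial at the n-th roots of unity,
    the pointwise product evaluates u(z) * k(z) there, and the inverse FFT
    interpolates the product's coefficients. The product has degree 2N - 2,
    so n = 2N points recover it exactly and its first N coefficients are the
    causal outputs. With n = N the product is only recovered modulo z**N - 1,
    i.e. a circular convolution in which later tokens wrap into earlier outputs.
    """
    N = len(u)
    n = 2 * N if pad else N
    y = np.fft.ifft(np.fft.fft(u, n) * np.fft.fft(k, n)).real
    return y[:N]

rng = np.random.default_rng(0)
N = 8
u = rng.standard_normal(N)
k = rng.standard_normal(N)

# Padded FFT matches the direct causal convolution y[t] = sum_{s<=t} u[s] k[t-s].
print(np.allclose(fft_conv(u, k), np.convolve(u, k)[:N]))  # True

# Perturb only the last token and compare the earlier outputs.
u2 = u.copy()
u2[-1] += 1.0
print(np.allclose(fft_conv(u2, k)[:-1], fft_conv(u, k)[:-1]))  # True: causal
print(np.allclose(fft_conv(u2, k, pad=False)[:-1], fft_conv(u, k, pad=False)[:-1]))  # False: leaks
```

In the unpadded case, changing the last token shifts every earlier output t by k[t + 1]. This is exactly the wrap-around term of the circular convolution.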