AI Summary
To address the prohibitively high cost of manually exploring the vast design space of hybrid neural architectures under large-scale pretraining, this paper introduces Composer, the first scalable neural architecture search (NAS) framework tailored to hybrid architectures. Its core innovations are: (i) modular modeling of attention and MLP components, enabling fine-grained architectural customization; and (ii) a novel scaling extrapolation strategy that efficiently transfers architectures found by small-scale search to large models (350M-3B parameters). Compared against Llama 3.2, architectures discovered by Composer achieve consistently lower validation loss and improve downstream task accuracy by 1.1-3.1% on average (up to 8.3%), while maintaining competitive training and inference efficiency. This work establishes the first systematic methodology for efficient NAS and cross-scale generalization in hybrid architectures.
Abstract
Hybrid model architectures that combine computational primitives (e.g., Attention, MLP) in different ratios have shown promising performance beyond Transformers. Some studies have shown that different interleavings of primitives can affect model quality as well. However, prior work has explored the hybrid model architecture design space manually. Due to the large design space and high training costs, discovering hybrid models that combine key computational primitives for pre-training is challenging. In this work, we take a principled approach to designing a modular hybrid model architecture search framework -- Composer. Composer explores model architectures at a small scale and extrapolates the top-performing architectures to larger scales using our proposed scaling strategies. Using Composer, we discover new hybrid LLM architectures that outperform Llama 3.2. Compared to Llama 3.2 and previous state-of-the-art baselines, the new architectures consistently reduce validation loss at parameter scales of 350M-3B and improve downstream task accuracy by 2.8-8.3% (1.1-3.1% on average), while improving both training and inference efficiency.
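The abstract's search-then-extrapolate loop can be pictured as follows. This is a minimal sketch, not the paper's implementation: the function names, the candidate encoding (a sequence of block types), and especially the toy proxy score standing in for small-scale pretraining loss are all assumptions made for illustration.

```python
import itertools

def enumerate_hybrids(num_blocks, primitives=("attn", "mlp")):
    """Enumerate all interleavings of primitives for a small model.

    Each candidate is a tuple of block types, e.g. ("attn", "mlp", ...),
    one entry per layer of the small-scale proxy model.
    """
    return list(itertools.product(primitives, repeat=num_blocks))

def proxy_loss(arch):
    """Stand-in for small-scale validation loss (assumption).

    A real search would pretrain a small model with this block pattern
    and report its validation loss; here we use a toy score that mildly
    prefers patterns with more alternation between block types.
    """
    alternations = sum(a != b for a, b in zip(arch, arch[1:]))
    return 1.0 - 0.01 * alternations

def search_top_k(num_blocks, k=3):
    """Rank candidate interleavings by proxy loss and keep the top k.

    In the framework described by the abstract, these survivors would
    then be extrapolated (widened/deepened via scaling strategies) and
    re-trained at the 350M-3B parameter scales.
    """
    candidates = enumerate_hybrids(num_blocks)
    return sorted(candidates, key=proxy_loss)[:k]

top_architectures = search_top_k(num_blocks=4)
```

The key design point the sketch captures is cost amortization: the expensive enumeration and ranking happen only on small models, and only the handful of winners ever see a large-scale training run.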