Birdie: Advancing State Space Models with Reward-Driven Objectives and Curricula

📅 2024-11-01
🏛️ arXiv.org
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
State Space Models (SSMs) underperform Transformers on long-context retrieval tasks such as text copying, associative recall, and long-document QA, despite their favorable O(L) time complexity. Method: We propose a training paradigm shift that preserves the SSM architecture: (1) a bidirectional linear state space modeling mechanism enabling seamless switching between understanding and generation modes; (2) a dynamic multi-objective pretraining framework jointly optimizing bidirectional input modeling and causal modeling; and (3) reinforcement learning to explicitly optimize retrieval-oriented objectives. The method is fully compatible with JAX and PyTorch. Results: Our approach substantially outperforms baseline SSMs on tasks including multi-digit phonebook lookup and long-passage QA, significantly narrowing the performance gap with Transformers while strictly maintaining O(L) computational complexity.
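The O(L) claim above comes from the SSM recurrence itself: the model keeps a fixed-size state and updates it once per token. A minimal sketch of a diagonal linear SSM scan (illustrative only; the paper's bidirectional SSM is more elaborate, and all names here are hypothetical):

```python
import numpy as np

def linear_ssm_scan(x, a, B, C):
    """Diagonal linear SSM: h_t = a * h_{t-1} + B @ x_t, y_t = C @ h_t.

    One sequential update per token, so time is O(L) and the state h
    stays a fixed size regardless of sequence length.
    """
    L, _ = x.shape
    h = np.zeros(a.shape[0])           # fixed-size recurrent state
    ys = np.empty((L, C.shape[0]))
    for t in range(L):                 # single O(L) pass over the sequence
        h = a * h + B @ x[t]           # elementwise decay + input injection
        ys[t] = C @ h                  # readout at every step
    return ys
```

The fixed-size state `h` is exactly what makes long-range retrieval hard: unlike attention, the model cannot look back at earlier tokens directly, which is the capacity the paper's training procedure tries to exploit better.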

๐Ÿ“ Abstract
Efficient state space models (SSMs), such as linear recurrent neural networks and linear attention variants, offer computational advantages over Transformers but struggle with tasks requiring long-range in-context retrieval, like text copying, associative recall, and question answering over long contexts. Previous efforts to address these challenges have focused on architectural modifications, often reintroducing computational inefficiencies. In this paper, we propose a novel training procedure, Birdie, that significantly enhances the in-context retrieval capabilities of SSMs without altering their architecture. Our approach combines bidirectional input processing with dynamic mixtures of specialized pre-training objectives, optimized via reinforcement learning. We introduce a new bidirectional SSM architecture that seamlessly transitions from bidirectional context processing to causal generation. Experimental evaluations demonstrate that Birdie markedly improves performance on retrieval-intensive tasks such as multi-number phone book lookup, long-paragraph question answering, and infilling. This narrows the performance gap with Transformers while retaining computational efficiency. Our findings highlight the importance of training procedures in leveraging the fixed-state capacity of SSMs, offering a new direction to advance their capabilities. All code and pre-trained models are available at https://www.github.com/samblouir/birdie, with support for JAX and PyTorch.
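The abstract's "dynamic mixtures of specialized pre-training objectives, optimized via reinforcement learning" can be pictured as a bandit over objectives: sample an objective per batch, observe a reward (e.g. validation improvement), and shift probability mass toward rewarding objectives. A toy sketch, assuming a softmax policy with a REINFORCE-style update (all names hypothetical; Birdie's actual controller is more involved):

```python
import numpy as np

class ObjectiveMixer:
    """Toy bandit that reweights pretraining objectives by observed reward."""

    def __init__(self, objectives, lr=0.1):
        self.objectives = objectives
        self.logits = np.zeros(len(objectives))  # softmax policy parameters
        self.lr = lr

    def probs(self):
        e = np.exp(self.logits - self.logits.max())  # stable softmax
        return e / e.sum()

    def sample(self, rng):
        """Pick which objective to train on for the next batch."""
        return rng.choice(len(self.objectives), p=self.probs())

    def update(self, arm, reward):
        """REINFORCE-style step: raise probability of rewarding objectives."""
        p = self.probs()
        grad = -p
        grad[arm] += 1.0                 # grad of log pi(arm) w.r.t. logits
        self.logits += self.lr * reward * grad
```

For example, repeatedly rewarding an infilling objective would shift the sampling distribution toward it while the other objectives decay, which is the qualitative behavior the dynamic mixture relies on.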
Problem

Research questions and friction points this paper is trying to address.

State Space Models
Long-term Memory Tasks
Transformer
Innovation

Methods, ideas, or system contributions that make the work stand out.

Bidirectional Information Processing
Reinforcement Learning Optimization
State Space Models (SSMs)
Sam Blouir
Department of Computer Science, George Mason University, Fairfax, VA
Jimmy Smith
Stanford University, Stanford, CA; Liquid AI, Palo Alto, CA
Antonios Anastasopoulos
Department of Computer Science, George Mason University, Fairfax, VA; Archimedes AI, Athena RC, Athens, Greece
Amarda Shehu
Professor of Computer Science, George Mason University
Computational Biology · Artificial Intelligence · Robotics · Biophysics · Bioinformatics