🤖 AI Summary
This work addresses the insufficient modeling of uncertainty and long-range dependencies in efficient sequence models by introducing an explicit, structured memory mechanism. The proposed approach writes observed evidence into memory by exact Bayesian filtering and generates predictive distributions through query-dependent retrieval, building Bayesian uncertainty modeling directly into the design of the sequence layer. The framework unifies several sub-quadratic recurrent models: linear attention, GLA, and Mamba-2/SSD are exact filters under one design model, while DeltaNet and related Delta-rule models arise as covariance-reset reductions under another. Experiments show that restoring the covariance improves robustness beyond the training regime on controlled collision studies, learned associative recall, and the Zoology MQAR benchmark. Distilling Bayesian Layers into a pretrained 340M Gated DeltaNet also improves RULER long-context retrieval at matched compute, which supports the role of covariance recovery in generalization.
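A short derivation makes the link between Delta-rule models and covariance dynamics concrete. For illustration, it assumes a linear-Gaussian memory in which each value is a noisy linear readout of a memory matrix M. The rows of M share one covariance P over key space, and observations have noise variance sigma^2. The symbols M, P, g_t, sigma^2, and c are our own notation, and this model need not match the paper's exact design model. Under these assumptions, resetting the covariance to a scaled identity before each write turns the exact Kalman update into the Delta rule:

```latex
% Assumed design model (illustrative): rows of M share covariance P; noise variance \sigma^2
v_t = M k_t + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, \sigma^2 I)

% Exact Bayesian (Kalman) write
g_t = \frac{P_{t-1} k_t}{k_t^\top P_{t-1} k_t + \sigma^2}, \qquad
M_t = M_{t-1} + (v_t - M_{t-1} k_t)\, g_t^\top, \qquad
P_t = P_{t-1} - g_t k_t^\top P_{t-1}

% Covariance reset P_{t-1} := c\, I with unit-norm key \|k_t\| = 1, so k_t^\top P_{t-1} k_t = c
g_t = \frac{c}{c + \sigma^2}\, k_t = \beta_t k_t, \qquad \beta_t = \frac{c}{c + \sigma^2} \in (0, 1)

% Resulting update: the Delta rule
M_t = M_{t-1} + \beta_t\, (v_t - M_{t-1} k_t)\, k_t^\top
```

Because beta_t can take any value in (0, 1) through the choice of c or sigma^2, a token-dependent write strength corresponds to a token-dependent reset. Keeping P_t instead of resetting it lets gains shrink along directions that have already received evidence.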
📝 Abstract
We introduce the design-model framework: a way to derive efficient recurrent sequence maps from explicit assumptions about memory. A design model writes evidence into memory by exact Bayesian filtering; a query-dependent readout produces a predictive distribution whose mean is the layer output. In our linear-Gaussian instantiation, the Bayesian Layer propagates both a mean and a covariance: the covariance tracks uncertainty over stored associations, steering writes toward uncertain directions, attenuating gains as evidence accumulates, and preserving confident memories. The same framework unifies several sub-quadratic recurrences. Linear attention, GLA, and Mamba-2/SSD are exact filters under one design model, whereas DeltaNet and related Delta-rule models arise as covariance-reset reductions under another. Restoring the covariance yields closed-form predictions for retrieval dynamics, which we verify empirically, and improves robustness beyond the training regime across controlled collision studies, learned associative recall, and the Zoology MQAR benchmark; distilling Bayesian Layers into a pretrained 340M-parameter Gated DeltaNet improves RULER long-context retrieval at matched compute.
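The covariance behaviors listed above can be illustrated with a minimal sketch. It assumes a design model chosen for illustration: values are noisy linear readouts v = M k + eps of a memory matrix M, and all rows of M share one Gaussian covariance P over key space. This is standard recursive Bayesian linear regression (a Kalman filter with a static state). It is not necessarily the paper's exact Bayesian Layer, and the class `GaussianMemory` and its methods are hypothetical names, not the authors' code. The write gain is proportional to P k, so it points toward uncertain directions. It shrinks as evidence accumulates, and entries that are already confident barely move. The readout returns a predictive mean M q together with a variance.

```python
from dataclasses import dataclass

Vector = list[float]
Matrix = list[list[float]]


def dot(x: Vector, y: Vector) -> float:
    return sum(a * b for a, b in zip(x, y))


def matvec(a: Matrix, x: Vector) -> Vector:
    return [dot(row, x) for row in a]


@dataclass
class GaussianMemory:
    """Hypothetical linear-Gaussian associative memory (illustrative, not the paper's code).

    Assumed design model: v = M k + eps, eps ~ N(0, noise * I), where every row of the
    d_v x d_k memory M shares one d_k x d_k Gaussian covariance P.
    """

    mean: Matrix  # posterior mean of M, shape d_v x d_k
    cov: Matrix   # shared posterior covariance P over key space, shape d_k x d_k
    noise: float  # observation noise variance sigma^2

    @classmethod
    def create(cls, d_k: int, d_v: int, prior_var: float = 1.0,
               noise: float = 0.1) -> "GaussianMemory":
        mean = [[0.0] * d_k for _ in range(d_v)]
        cov = [[prior_var if i == j else 0.0 for j in range(d_k)] for i in range(d_k)]
        return cls(mean, cov, noise)

    def write(self, k: Vector, v: Vector) -> Vector:
        """Exact Bayesian (Kalman) update for one (key, value) pair; returns the gain."""
        pk = matvec(self.cov, k)        # P k: large along uncertain directions
        s = dot(k, pk) + self.noise     # innovation variance k^T P k + sigma^2
        gain = [x / s for x in pk]      # g = P k / s
        residual = [vi - mi for vi, mi in zip(v, matvec(self.mean, k))]  # v - M k
        # M <- M + (v - M k) g^T
        self.mean = [[m + r * g for m, g in zip(row, gain)]
                     for row, r in zip(self.mean, residual)]
        # P <- P - g k^T P, written as P - g (P k)^T because P is symmetric
        self.cov = [[p - gi * pkj for p, pkj in zip(row, pk)]
                    for row, gi in zip(self.cov, gain)]
        return gain

    def read(self, q: Vector) -> tuple[Vector, float]:
        """Predictive mean M q and per-dimension variance q^T P q + sigma^2."""
        return matvec(self.mean, q), dot(q, matvec(self.cov, q)) + self.noise
```

For example, take `GaussianMemory.create(d_k=2, d_v=1)`. Writing the key [1, 0] twice lowers its gain from about 0.91 to about 0.48. Querying the unseen key [0, 1] still reports the prior-plus-noise variance 1.1. A later write with key [0.6, 0.8] puts most of its gain on the second, still-uncertain coordinate, so the stored association for [1, 0] changes only slightly.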