Auto-Regressive Masked Diffusion Models

📅 2026-01-23
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the limitations of masked diffusion language models, which suffer from low training efficiency and inferior performance compared to autoregressive counterparts. The authors propose a block-wise causal diffusion architecture that enforces strict causality and permutation equivariance, enabling joint training on both standard and randomly permuted word orders. This design allows parallel computation of multi-step denoising conditional probabilities within a single forward pass. By integrating progressive permutation training with a cross-step parallel generation strategy, the method significantly accelerates inference while preserving global coherence. Experiments demonstrate that the model achieves state-of-the-art performance on standard language modeling benchmarks, drastically reduces the number of required training steps, and effectively narrows the performance gap between parallel generation and autoregressive decoding.

📝 Abstract
Masked diffusion models (MDMs) have emerged as a promising approach for language modeling, yet they face a performance gap compared to autoregressive models (ARMs) and require more training iterations. In this work, we present the Auto-Regressive Masked Diffusion (ARMD) model, an architecture designed to close this gap by unifying the training efficiency of autoregressive models with the parallel generation capabilities of diffusion-based models. Our key insight is to reframe the masked diffusion process as a block-wise causal model. This perspective allows us to design a strictly causal, permutation-equivariant architecture that computes all conditional probabilities across multiple denoising steps in a single, parallel forward pass. The resulting architecture supports efficient, autoregressive-style decoding and a progressive permutation training scheme, allowing the model to learn both canonical left-to-right and random token orderings. Leveraging this flexibility, we introduce a novel strided parallel generation strategy that accelerates inference by generating tokens in parallel streams while maintaining global coherence. Empirical results demonstrate that ARMD achieves state-of-the-art performance on standard language modeling benchmarks, outperforming established diffusion baselines while requiring significantly fewer training steps. Furthermore, it establishes a new benchmark for parallel text generation, effectively bridging the performance gap between parallel and sequential decoding.
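The abstract's "block-wise causal model" can be illustrated with a toy attention mask. This is a minimal sketch of one plausible reading, not the paper's actual architecture: `block_size` and the choice of full bidirectional attention within a block are assumptions.

```python
import numpy as np

def blockwise_causal_mask(seq_len, block_size):
    """Toy block-wise causal attention mask (an assumed reading).

    Position i may attend to position j iff j's block is not later
    than i's block: tokens within a block attend to each other
    (supporting parallel denoising inside the block), while blocks
    as units remain strictly causal.
    """
    blocks = np.arange(seq_len) // block_size
    # mask[i, j] = True where attention from i to j is allowed
    return blocks[:, None] >= blocks[None, :]

mask = blockwise_causal_mask(seq_len=6, block_size=2)
```

With `block_size=2`, positions 0 and 1 see each other but neither sees position 2, while position 5 sees everything before it, which is the block-level causality the abstract describes.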
Problem

Research questions and friction points this paper is trying to address.

masked diffusion models
autoregressive models
language modeling
parallel generation
training efficiency
Innovation

Methods, ideas, or system contributions that make the work stand out.

Auto-Regressive Masked Diffusion
Parallel Generation
Causal Architecture
Permutation-Equivariant Training
Strided Inference
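The "Strided Inference" idea above (parallel generation streams with global coherence) might be sketched as follows. This is a speculative toy, not the authors' algorithm: `predict`, `num_streams`, and the contiguous-segment split are all assumptions for illustration.

```python
def strided_generation(predict, seq_len, num_streams):
    """Toy sketch of strided parallel decoding (an assumed reading).

    The sequence is split into `num_streams` contiguous segments.
    At each step, every stream fills the next position of its own
    segment, so `num_streams` tokens are produced per step instead
    of one. `predict(tokens, pos)` stands in for a model call that
    returns a token for position `pos` given the partial sequence.
    """
    seg = (seq_len + num_streams - 1) // num_streams  # segment length
    tokens = [None] * seq_len
    for step in range(seg):
        for s in range(num_streams):  # conceptually parallel calls
            pos = s * seg + step
            if pos < seq_len:
                tokens[pos] = predict(tokens, pos)
    return tokens

# Toy predictor that just echoes the position index.
out = strided_generation(lambda toks, pos: pos, seq_len=7, num_streams=3)
```

With `seq_len=7` and three streams, decoding finishes in three steps rather than seven, which is the kind of inference speed-up the paper claims from cross-step parallel generation.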
Authors
Mahdi Karami (School of Computer Science, University of Waterloo, ON, Canada & Google Research)
Ali Ghodsi (UC Berkeley, Databricks)