🤖 AI Summary
This work addresses the limitations of masked diffusion language models, which suffer from low training efficiency and inferior performance compared to autoregressive counterparts. The authors propose a block-wise causal diffusion architecture that enforces strict causality and permutation equivariance, enabling joint training on both the canonical left-to-right order and randomly permuted token orders. This design allows the conditional probabilities of multiple denoising steps to be computed in parallel within a single forward pass. By integrating progressive permutation training with a strided parallel generation strategy, the method significantly accelerates inference while preserving global coherence. Experiments demonstrate that the model achieves state-of-the-art performance on standard language modeling benchmarks, drastically reduces the number of required training steps, and effectively narrows the performance gap between parallel generation and autoregressive decoding.
📝 Abstract
Masked diffusion models (MDMs) have emerged as a promising approach for language modeling, yet they face a performance gap compared to autoregressive models (ARMs) and require more training iterations. In this work, we present the Auto-Regressive Masked Diffusion (ARMD) model, an architecture designed to close this gap by unifying the training efficiency of autoregressive models with the parallel generation capabilities of diffusion-based models. Our key insight is to reframe the masked diffusion process as a block-wise causal model. This perspective allows us to design a strictly causal, permutation-equivariant architecture that computes all conditional probabilities across multiple denoising steps in a single, parallel forward pass. The resulting architecture supports efficient, autoregressive-style decoding and a progressive permutation training scheme, allowing the model to learn both canonical left-to-right and random token orderings. Leveraging this flexibility, we introduce a novel strided parallel generation strategy that accelerates inference by generating tokens in parallel streams while maintaining global coherence. Empirical results demonstrate that ARMD achieves state-of-the-art performance on standard language modeling benchmarks, outperforming established diffusion baselines while requiring significantly fewer training steps. Furthermore, it establishes a new benchmark for parallel text generation, effectively bridging the performance gap between parallel and sequential decoding.
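The abstract's key insight is a block-wise causal structure: a query position may attend only to positions in its own block or earlier blocks, so the conditional probabilities for several denoising steps can be evaluated in one parallel forward pass. The paper does not spell out the exact masking rule, so the following is only a minimal illustrative sketch of a generic block-wise causal attention mask, assuming contiguous blocks of equal size and full attention within a block:

```python
import numpy as np

def blockwise_causal_mask(seq_len: int, block_size: int) -> np.ndarray:
    """Boolean (seq_len, seq_len) mask: entry [q, k] is True iff query
    position q may attend to key position k, i.e. iff k's block index
    is <= q's block index.  Blocks are contiguous spans of block_size
    tokens; this is an illustrative assumption, not the paper's exact rule."""
    blocks = np.arange(seq_len) // block_size      # block index of each position
    return blocks[None, :] <= blocks[:, None]      # broadcast compare over pairs

# With seq_len=6 and block_size=2 the blocks are [0,0,1,1,2,2]:
# position 0 can attend within block 0 (including position 1),
# but not ahead to block 1 or block 2.
mask = blockwise_causal_mask(seq_len=6, block_size=2)
```

With block_size=1 this degenerates to the familiar causal mask of an autoregressive transformer, which is one way to see how the design inherits ARM-style decoding while still allowing whole blocks to be denoised in parallel.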