AI Summary
This work addresses the high computational cost of conventional Masked Generative Models (MGMs) and their poor generation quality under limited sampling budgets, both of which stem from running a fixed-depth bidirectional Transformer at every denoising step. The authors propose Fixed-Point Masked Generative Models (FP-MGMs), which replace part of the denoiser with a fixed-point solver over shared attention layers, enabling adaptive computation depth with fewer parameters. They add a cross-step consistency loss that aligns hidden representations at neighboring denoising steps, and a three-state reuse (3SR) mechanism that warm-starts the solver from the previous solution while treating unchanged, still-masked, and newly revealed tokens differently. The complete training-to-inference framework is called CoFRe. Pretrained MGMs can be converted into FP-MGMs with short fine-tuning instead of full retraining. On OpenWebText, CoFRe reduces parameters by 38.8%, training time by 11.5%, and VRAM by 16.9% compared to MDLM, while improving generative perplexity from 830.8 to 101.8 at a budget of 96 transformer-block forward passes. On ImageNette, it reduces training time by 48.6% and VRAM by 50.7%, while improving FID at all sampling budgets tested.
Abstract
Masked Generative Models (MGMs) enable parallel decoding and achieve strong performance across modalities, but require full-sequence bidirectional transformers at every step, making training costly and degrading quality under low sampling budgets. Existing work improves efficiency via better samplers or cheaper fixed-depth denoisers, but still allocates a fixed amount of denoiser computation to each refinement step. We introduce Fixed-Point Masked Generative Models (FP-MGMs), which replace part of the denoiser with a fixed-point solver over shared attention layers to enable adaptive depth with fewer parameters. To make the solver more effective for masked generation, we introduce two components. The first is a cross-step consistency loss, which aligns hidden representations at neighboring denoising steps. The second is three-state reuse (3SR), which warm-starts the solver from the previous solution while treating unchanged, still-masked, and newly revealed tokens differently. Together, these components define CoFRe, our complete training-to-inference framework for fixed-point masked generation. We also show that pre-trained MGMs can be converted into FP-MGMs with short fine-tuning, avoiding full retraining. Across modalities, CoFRe improves the trade-off between quality and cost. On OpenWebText, CoFRe reduces parameters by 38.8%, training time by 11.5%, and VRAM by 16.9%, while improving generative perplexity from 830.8 to 101.8 at a budget of 96 transformer-block forward passes, compared to MDLM. On ImageNette, CoFRe reduces training time by 48.6% and VRAM by 50.7%, while improving FID at all sampling budgets tested. Overall, CoFRe offers a practical framework for cheaper training and stronger low-budget masked generation.
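To make the idea of a warm-started fixed-point solver concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes a toy contractive layer (`layer`, a tanh of a linear map with spectral norm 0.5) standing in for the shared attention layers, and a plain iteration `fixed_point_solve` with a relative-update stopping rule; both names, the tolerance, and the toy inputs are hypothetical. The per-token rules of 3SR are not specified in the abstract and are not modeled here; the sketch only shows that starting from the previous step's solution can reduce the number of layer applications when consecutive inputs are close.

```python
import numpy as np


def fixed_point_solve(f, x, z0, tol=1e-6, max_iter=200):
    """Iterate z <- f(z, x) until the relative update falls below tol.

    Returns the final z and the number of applications of f that were used,
    so the effective depth adapts to how hard the input is.
    """
    z = z0
    for k in range(1, max_iter + 1):
        z_next = f(z, x)
        rel = np.linalg.norm(z_next - z) / (np.linalg.norm(z_next) + 1e-12)
        z = z_next
        if rel < tol:
            return z, k
    return z, max_iter


rng = np.random.default_rng(0)
d = 16  # hidden size (toy value, an assumption)

# Toy shared layer: rescale W to spectral norm 0.5 so the map is a contraction
# and the iteration is guaranteed to converge to a unique fixed point.
W = rng.standard_normal((d, d))
W *= 0.5 / np.linalg.norm(W, 2)


def layer(z, x):
    """Hypothetical stand-in for the shared attention layers."""
    return np.tanh(z @ W.T + x)


# Inputs at two neighboring denoising steps: 8 tokens, slightly perturbed.
x_prev = rng.standard_normal((8, d))
x_next = x_prev + 0.01 * rng.standard_normal((8, d))

# Solve the previous step from a cold (all-zero) start.
z_prev, _ = fixed_point_solve(layer, x_prev, np.zeros_like(x_prev))

# Solve the next step from a cold start and from the previous solution.
z_cold, iters_cold = fixed_point_solve(layer, x_next, np.zeros_like(x_next))
z_warm, iters_warm = fixed_point_solve(layer, x_next, z_prev)

print("cold-start iterations:", iters_cold)
print("warm-start iterations:", iters_warm)
```

Both runs reach the same fixed point because the toy layer is a contraction; the warm start simply begins closer to it. A real denoiser offers no such guarantee by default, which is presumably one reason a training-time loss that keeps neighboring-step representations close is useful.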