AI Summary
SoftMax struggles to simultaneously achieve sparsity and multimodality in cross-modal attention, limiting its ability to capture multiple salient, semantically distinct alignments. To address this, we propose MultiMax, a differentiable piecewise function that employs an input-range-aware adaptive scaling mechanism to dynamically redistribute attention probabilities, explicitly preserving multiple prominent peaks while suppressing irrelevant responses. MultiMax is the first method to jointly model sparsity and multimodal structure, overcoming fundamental expressivity limitations of SoftMax and its variants, and it integrates seamlessly into standard training pipelines as a plug-and-play replacement. Extensive experiments on image classification, language modeling, and machine translation demonstrate consistent improvements in both task performance and attention interpretability. MultiMax effectively suppresses noisy attention responses, validating that multimodal attention yields substantive gains for cross-modal understanding.
Abstract
SoftMax is a ubiquitous ingredient of modern machine learning algorithms. It maps an input vector onto a probability simplex and reweights the input by concentrating the probability mass at large entries. Yet, as a smooth approximation to the Argmax function, SoftMax distributes a significant amount of probability mass to the remaining, residual entries, leading to poor interpretability and noise. Although sparsity can be achieved by a family of SoftMax variants, they often require an alternative loss function and do not preserve multi-modality. We show that this trade-off between multi-modality and sparsity limits the expressivity of SoftMax as well as its variants. We provide a solution to this tension between objectives by proposing a piecewise differentiable function, termed MultiMax, which adaptively modulates the output distribution according to the input entry range. Through comprehensive analysis and evaluation, we show that MultiMax successfully produces a distribution that suppresses irrelevant entries while preserving multi-modality, with benefits in image classification, language modeling and machine translation. The code is available at https://github.com/ZhouYuxuanYX/MultiMax.
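The idea described above can be illustrated with a minimal sketch: apply a continuous piecewise map to the logits before a standard softmax, so that entries below a low threshold are pushed further down (sparsity) while entries above a high threshold are amplified (preserving multiple peaks). The thresholds `low`/`high` and slopes `s_low`/`s_high` here are illustrative assumptions, not the paper's exact parameterization (see the linked repository for the actual implementation).

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

def multimax_sketch(x, low=0.0, high=1.0, s_low=2.0, s_high=2.0):
    # Illustrative piecewise modulation (hypothetical parameters):
    # entries below `low` are scaled further down, entries above `high`
    # are scaled further up, and the middle range is left unchanged.
    # The map is continuous at both thresholds, hence piecewise
    # differentiable, and a plain softmax normalizes the result.
    y = np.where(x < low, low + s_low * (x - low), x)
    y = np.where(x > high, high + s_high * (y - high), y)
    return softmax(y)

# Two close, large logits (a bimodal signal) plus two small ones.
x = np.array([3.0, 2.9, 0.0, -1.0])
p_sm = softmax(x)
p_mm = multimax_sketch(x)
```

Compared with plain softmax, the sketch assigns less mass to the small entries while both large entries remain prominent, which is the sparsity-with-multi-modality behavior the abstract describes.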