🤖 AI Summary
This work addresses the quadratic growth in the computational and memory costs of attention when processing long input sequences. To overcome this limitation, the authors propose WildCat, which approximates attention by attending only over a small weighted coreset. The coreset is selected with the randomly pivoted Cholesky algorithm, and its elements are weighted optimally to minimize reconstruction error. Given bounded inputs, the method runs in near-linear $O(n^{1+o(1)})$ time while achieving super-polynomial $O(n^{-\sqrt{\log(\log(n))}})$ error decay. By contrast, prior practical approximations either lack error guarantees or require quadratic runtime to guarantee comparable fidelity. The method comes with a GPU-optimized PyTorch implementation, and benchmark experiments demonstrate its benefits for image generation, image classification, and key-value (KV) cache compression in language models.
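The coreset-selection step can be illustrated with a minimal NumPy sketch of randomly pivoted Cholesky on an implicitly defined positive semidefinite kernel matrix. This is not the authors' GPU implementation. The exponential kernel over keys, the helper names `rp_cholesky` and `kernel_col`, and the problem sizes are all assumptions chosen for illustration. At each step the algorithm samples a pivot with probability proportional to the residual diagonal. Points that the current low-rank approximation explains poorly are therefore more likely to be picked.

```python
import numpy as np

def rp_cholesky(kernel_col, diag, k, rng):
    """Randomly pivoted Cholesky on an implicit n x n PSD matrix A.

    kernel_col(j) returns column j of A and diag holds diag(A); only k columns
    are evaluated. Returns pivot indices S and a factor F (n x k) with A ~ F @ F.T.
    """
    n = diag.shape[0]
    F = np.zeros((n, k))
    d = diag.astype(float).copy()                  # residual diagonal of A - F F^T
    pivots = []
    for i in range(k):
        total = d.sum()
        if total <= 0:                             # A is reproduced exactly; stop early
            F = F[:, :i]
            break
        s = rng.choice(n, p=d / total)             # sample pivot proportional to residual
        g = kernel_col(s) - F[:, :i] @ F[s, :i]    # residual column s
        F[:, i] = g / np.sqrt(g[s])
        d = np.clip(d - F[:, i] ** 2, 0.0, None)   # update residual, guard against roundoff
        pivots.append(s)
    return np.array(pivots, dtype=int), F

# Illustrative use: select k of n keys under an assumed softmax-style kernel.
rng = np.random.default_rng(0)
n, dim, k = 512, 16, 32
keys = rng.standard_normal((n, dim)) / dim ** 0.25  # so keys @ keys.T = <x, y> / sqrt(dim)

def kernel_col(j):
    return np.exp(keys @ keys[j])                   # column j of exp(<x, y> / sqrt(dim))

diag = np.exp(np.sum(keys ** 2, axis=1))
S, F = rp_cholesky(kernel_col, diag, k, rng)
```

Only k kernel columns are ever evaluated. The cost is therefore O(nk) kernel evaluations plus O(nk²) arithmetic, rather than the O(n²) needed to form the full matrix.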
📝 Abstract
We introduce WildCat, a high-accuracy, low-cost approach to compressing the attention mechanism in neural networks. While attention is a staple of modern network architectures, it is also notoriously expensive to deploy due to resource requirements that scale quadratically with the input sequence length $n$. WildCat avoids these quadratic costs by only attending over a small weighted coreset. Crucially, we select the coreset using a fast but spectrally-accurate subsampling algorithm -- randomly pivoted Cholesky -- and weight the elements optimally to minimise reconstruction error. Remarkably, given bounded inputs, WildCat approximates exact attention with super-polynomial $O(n^{-\sqrt{\log(\log(n))}})$ error decay while running in near-linear $O(n^{1+o(1)})$ time. In contrast, prior practical approximations either lack error guarantees or require quadratic runtime to guarantee such high fidelity. We couple this advance with a GPU-optimized PyTorch implementation and a suite of benchmark experiments demonstrating the benefits of WildCat for image generation, image classification, and language model KV cache compression.
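The following NumPy sketch shows how attending over a weighted coreset can replace full attention. It is a hypothetical illustration, not WildCat's algorithm:

- The helper names `exact_attention` and `coreset_attention` are invented.
- The coreset is drawn uniformly at random instead of by randomly pivoted Cholesky.
- The weights use a Nyström-style (kernel-quadrature) formula with a small ridge term, which may differ from the paper's optimal weighting.

The point is structural. Each query is scored against only |S| keys, and the weighted values are computed once and shared across all queries.

```python
import numpy as np

def exact_attention(Q, K, V):
    """Reference softmax attention: every query is scored against all n keys."""
    logits = Q @ K.T / np.sqrt(Q.shape[1])
    logits -= logits.max(axis=1, keepdims=True)      # stabilise exp; cancels in the ratio
    A = np.exp(logits)
    return (A @ V) / A.sum(axis=1, keepdims=True)

def coreset_attention(Q, K, V, S, ridge=1e-8):
    """Hypothetical sketch: attend over keys K[S] with Nystrom-style weighted values."""
    scale = 1.0 / np.sqrt(Q.shape[1])

    def kern(X, Y):
        # exp(<x, y> / sqrt(d)), the unnormalised softmax kernel
        return np.exp(X @ Y.T * scale)

    K_S = K[S]
    G = kern(K_S, K_S) + ridge * np.eye(len(S))     # |S| x |S| Gram matrix (regularised)
    C = kern(K_S, K)                                # |S| x n cross-kernel
    V1 = np.hstack([V, np.ones((K.shape[0], 1))])   # extra ones column yields the normaliser
    W = np.linalg.solve(G, C @ V1)                  # |S| x (dv + 1) weighted values
    out = kern(Q, K_S) @ W                          # each query touches only |S| keys
    return out[:, :-1] / out[:, -1:]

rng = np.random.default_rng(0)
n, m, dim, dv = 200, 10, 8, 4
Q = rng.standard_normal((m, dim))
K = rng.standard_normal((n, dim))
V = rng.standard_normal((n, dv))
S = rng.choice(n, size=40, replace=False)           # uniform coreset (assumption, not RPCholesky)
approx = coreset_attention(Q, K, V, S)
exact = exact_attention(Q, K, V)
print("max abs error:", np.abs(approx - exact).max())
```

Building `W` costs O(n|S|) kernel evaluations plus one |S| x |S| solve, and each query then costs O(|S|). For |S| much smaller than n, the total is far below the O(mn) of exact attention. Passing `S = np.arange(n)` recovers exact attention up to the ridge term, which makes a convenient sanity check.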