🤖 AI Summary
This work addresses the memory and computational inefficiencies of standard Gaussian Mixture Model (GMM) implementations on large-scale datasets, which limit how much data can be processed on a single GPU. The authors propose Flash-GMM, the first method to train GMMs without explicitly materializing the responsibility matrix. Using a fused Triton kernel, Flash-GMM performs the soft-clustering computations in a single GPU pass. This improves both memory usage and throughput, achieving up to a 20× speedup over existing implementations and enabling training on datasets more than 100× larger than was previously feasible on one device. Furthermore, as a drop-in replacement for k-means in approximate nearest neighbor search, Flash-GMM raises recall@10 by 2–12 percentage points at matched cost, or reaches the same recall with up to 1.7× fewer distance computations.
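The central idea, that an EM iteration only needs per-component sufficient statistics rather than the full N×K responsibility matrix, can be illustrated with a minimal NumPy sketch. It assumes diagonal covariances and processes the data in row tiles: each tile's responsibilities are normalized with a log-sum-exp, folded into running sums, and then discarded. This is not the authors' Triton kernel, which fuses these steps on-chip; the function names, the `tile` size, and the variance floor `reg` are illustrative assumptions.

```python
import numpy as np


def log_gauss_diag(x, means, variances):
    """Log-density of each row of x under each diagonal Gaussian.

    x: (n, d), means: (k, d), variances: (k, d) -> returns (n, k).
    """
    d = x.shape[1]
    inv_var = 1.0 / variances  # (k, d)
    # Expand sum_d (x_d - mu_d)^2 / var_d to avoid building an (n, k, d) tensor.
    quad = (
        (x ** 2) @ inv_var.T
        - 2.0 * x @ (means * inv_var).T
        + np.sum(means ** 2 * inv_var, axis=1)
    )
    log_det = np.sum(np.log(variances), axis=1)  # (k,)
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + quad)


def streaming_em_step(x, weights, means, variances, tile=1024, reg=1e-6):
    """One EM iteration that holds at most a (tile, k) block of responsibilities.

    Illustrative sketch (diagonal covariances assumed), not the Flash-GMM kernel.
    """
    n, d = x.shape
    k = means.shape[0]
    nk = np.zeros(k)          # soft counts:        sum_i r_ik
    sx = np.zeros((k, d))     # first moments:      sum_i r_ik * x_i
    sxx = np.zeros((k, d))    # second moments:     sum_i r_ik * x_i**2
    total_loglik = 0.0
    log_w = np.log(weights)
    for start in range(0, n, tile):
        xb = x[start:start + tile]
        # E-step on this tile only: unnormalized log-responsibilities (b, k).
        logp = log_gauss_diag(xb, means, variances) + log_w
        # Numerically stable log-sum-exp over components.
        m = logp.max(axis=1, keepdims=True)
        lse = m + np.log(np.exp(logp - m).sum(axis=1, keepdims=True))
        r = np.exp(logp - lse)  # rows sum to 1
        total_loglik += lse.sum()
        # Fold the tile into the running statistics; r is then dropped.
        nk += r.sum(axis=0)
        sx += r.T @ xb
        sxx += r.T @ (xb ** 2)
    # M-step computed purely from the accumulated statistics.
    new_weights = nk / n
    new_means = sx / nk[:, None]
    new_vars = sxx / nk[:, None] - new_means ** 2 + reg
    return new_weights, new_means, new_vars, total_loglik / n
```

Setting `tile=len(x)` reduces the loop to the usual dense E-step, which gives a simple way to check that tiling changes memory use but not the result.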
📝 Abstract
We present \textbf{Flash-GMM}, a fused Triton kernel for efficient computation of Gaussian Mixture Models (GMMs) over large-scale data in a single GPU pass. By eliminating the need to materialize the full responsibility matrix in GPU memory, Flash-GMM achieves a \textbf{20$\times$} speedup over existing implementations and enables training on datasets more than \textbf{100$\times$} larger than previously feasible on one device. To demonstrate its impact, we integrate Flash-GMM into the IVF coarse quantizer for approximate nearest-neighbor (ANN) search. We show that soft GMM clustering is now a viable drop-in replacement for $k$-means, and that GMM responsibilities can be leveraged to assign border vectors to multiple clusters. Our approach reaches fixed recall targets with up to $1.7\times$ fewer distance computations, or equivalently, yields $+2$ to $+12$ points of recall@10 at matched computational cost. We release the kernel as an open-source project.
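The abstract does not specify the exact rule by which responsibilities send border vectors to multiple clusters. One plausible rule, sketched here under that assumption, keeps every vector in its highest-responsibility inverted list (as a k-means-based IVF index would) and additionally replicates it into up to `max_assign - 1` further lists whose responsibility is at least a threshold `tau`. The function name and both parameters are hypothetical.

```python
import numpy as np


def build_inverted_lists(resp, tau=0.2, max_assign=2):
    """Build IVF inverted lists from a responsibility matrix.

    Hypothetical assignment rule, not taken from the paper.
    resp: (n, k) array whose rows sum to 1.
    Returns a list of k lists of vector ids.
    """
    n, k = resp.shape
    lists = [[] for _ in range(k)]
    # Clusters for each vector, sorted by descending responsibility.
    order = np.argsort(-resp, axis=1)
    for i in range(n):
        # Primary assignment: the most responsible cluster, as with k-means.
        lists[order[i, 0]].append(i)
        # Secondary assignments for border vectors whose runner-up
        # responsibilities clear the threshold.
        for c in order[i, 1:max_assign]:
            if resp[i, c] >= tau:
                lists[c].append(i)
    return lists
```

For example, with `tau=0.25` a vector with responsibilities `[0.55, 0.45, 0.0]` lands in lists 0 and 1, while one with `[0.9, 0.1, 0.0]` stays only in list 0. Because replicated ids can appear in several probed lists, candidates would typically be deduplicated before exact reranking.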