🤖 AI Summary
Problem: Existing click models based on probabilistic graphical models (PGMs) rely on the expectation-maximization (EM) algorithm. This makes them hard to integrate with modern deep learning frameworks and prevents practitioners from combining those frameworks with the interpretability of classic models.
Method: This paper introduces CLAX, a JAX-based click modeling library that systematically replaces EM with gradient-based optimization for classic click models, enabling end-to-end, differentiable training. Using JAX's automatic differentiation and XLA compilation, the framework lets embedding layers, deep neural networks, and user-defined modules be plugged into these models.
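Replacing EM with gradient-based training can be illustrated with a minimal sketch. It assumes a position-based model (PBM) whose examination and attractiveness parameters are stored as per-rank and per-document logit tables, which are effectively embedding lookups. `jax.value_and_grad` supplies the gradients and `jax.jit` compiles the update with XLA. The names `pbm_nll`, `log1mexp`, and `train_step`, the synthetic data, and the plain gradient-descent update are hypothetical illustrations. They are not CLAX's API.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.5  # hypothetical; plain gradient descent, no optimizer library


def log1mexp(x):
    """Stable log(1 - exp(x)) for x < 0."""
    x = jnp.minimum(x, -1e-6)  # keep the argument strictly negative so both branches stay finite
    return jnp.where(x > -0.6931, jnp.log(-jnp.expm1(x)), jnp.log1p(-jnp.exp(x)))


def pbm_nll(params, ranks, docs, clicks):
    """Negative log-likelihood of a PBM: P(click) = sigmoid(exam[rank]) * sigmoid(attr[doc])."""
    # Work in log space: log P(click) = log sigmoid(exam) + log sigmoid(attr)
    log_p = jax.nn.log_sigmoid(params["exam"][ranks]) + jax.nn.log_sigmoid(params["attr"][docs])
    return -jnp.mean(clicks * log_p + (1.0 - clicks) * log1mexp(log_p))


@jax.jit
def train_step(params, ranks, docs, clicks):
    # One full-batch gradient step takes the place of one EM iteration.
    loss, grads = jax.value_and_grad(pbm_nll)(params, ranks, docs, clicks)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


# Synthetic clicks drawn from a known PBM (3 ranks, 5 documents).
num_ranks, num_docs, n = 3, 5, 20_000
k_attr, k_rank, k_doc, k_click = jax.random.split(jax.random.PRNGKey(0), 4)
true_exam = jnp.array([0.9, 0.5, 0.2])
true_attr = jax.random.uniform(k_attr, (num_docs,), minval=0.1, maxval=0.9)
ranks = jax.random.randint(k_rank, (n,), 0, num_ranks)
docs = jax.random.randint(k_doc, (n,), 0, num_docs)
clicks = jax.random.bernoulli(k_click, true_exam[ranks] * true_attr[docs]).astype(jnp.float32)

# Start from P(click) = 0.25 everywhere and fit by gradient descent.
params = {"exam": jnp.zeros(num_ranks), "attr": jnp.zeros(num_docs)}
initial_loss = pbm_nll(params, ranks, docs, clicks)
for _ in range(500):
    params, _ = train_step(params, ranks, docs, clicks)
final_loss = pbm_nll(params, ranks, docs, clicks)
```

Unlike EM, which requires hand-derived update rules for each model, this loop only needs a differentiable likelihood. In this sketch, swapping the logit tables for a neural network changes only how `log_p` is computed. Note that PBM examination and attractiveness are identified only up to a shared scale, so the fitted tables need not match `true_exam` and `true_attr` exactly.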
Results: On the full Baidu-ULTR dataset (over one billion sessions), training takes approximately two hours on a single GPU, orders of magnitude faster than EM. The library implements ten classic click models. The core contribution is a numerically stable, gradient-based realization of PGM-based click models that supports efficient, modular, large-scale training while preserving their structural interpretability.
📝 Abstract
CLAX is a JAX-based library that implements classic click models using modern gradient-based optimization. While neural click models have emerged over the past decade, complex click models based on probabilistic graphical models (PGMs) have not systematically adopted gradient-based optimization, preventing practitioners from leveraging modern deep learning frameworks while preserving the interpretability of classic models. CLAX addresses this gap by replacing EM-based optimization with direct gradient-based optimization in a numerically stable manner. The framework's modular design enables the integration of any component, from embeddings and deep networks to custom modules, into classic click models for end-to-end optimization. We demonstrate CLAX's efficiency by running experiments on the full Baidu-ULTR dataset comprising over a billion user sessions in approximately 2 hours on a single GPU, orders of magnitude faster than traditional EM approaches. CLAX implements ten classic click models, serving both industry practitioners seeking to understand user behavior and improve ranking performance at scale and researchers developing new click models. CLAX is available at: https://github.com/philipphager/clax
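Two claims in the abstract can be illustrated together: numerically stable gradient-based optimization, and plugging arbitrary modules into a classic click model. The sketch below assumes the cascade model, in which a user scans results top-down and clicks the first attractive one. A session with a click at rank $c$ therefore has likelihood $\prod_{i<c}(1-\alpha_i)\,\alpha_c$, and a session without clicks has likelihood $\prod_i (1-\alpha_i)$. Computing $\log\alpha$ and $\log(1-\alpha)$ directly from logits with `jax.nn.log_sigmoid` avoids the `log(1 - sigmoid(x))` underflow. The two-layer MLP producing attractiveness logits, the feature shapes, and all function names are hypothetical. They are not meant to describe how CLAX itself achieves stability.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden):
    """Hypothetical two-layer MLP mapping query-document features to attractiveness logits."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden,)) * 0.1,
        "b2": jnp.zeros(()),
    }


def attractiveness_logits(params, features):
    # features: [sessions, ranks, in_dim] -> logits: [sessions, ranks]
    h = jax.nn.relu(features @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def cascade_log_likelihood(logits, clicks):
    """Per-session cascade-model log-likelihood, computed entirely in log space.

    clicks: [sessions, ranks], at most one click per session.
    """
    log_attr = jax.nn.log_sigmoid(logits)  # log(alpha)
    log_not_attr = jax.nn.log_sigmoid(-logits)  # log(1 - alpha) without forming 1 - p
    # 1.0 for ranks strictly above the click (all ranks if there is no click), else 0.0
    before_click = 1.0 - jnp.cumsum(clicks, axis=-1)
    return jnp.sum(clicks * log_attr + before_click * log_not_attr, axis=-1)


def loss_fn(params, features, clicks):
    logits = attractiveness_logits(params, features)
    return -jnp.mean(cascade_log_likelihood(logits, clicks))


# Gradients flow end to end from the click likelihood into the MLP weights.
grad_fn = jax.jit(jax.grad(loss_fn))

# Stability check: a highly attractive first result that the user skipped.
extreme_logits = jnp.array([[50.0, 0.0]])
click_on_second = jnp.array([[0.0, 1.0]])
stable = cascade_log_likelihood(extreme_logits, click_on_second)  # about -50.69
naive = jnp.log(1.0 - jax.nn.sigmoid(extreme_logits[0, 0])) + jnp.log(jax.nn.sigmoid(0.0))  # -inf in float32

# Usage on random features: 8 sessions, 4 ranks, 6 features, click on the second result.
params = init_mlp(jax.random.PRNGKey(0), in_dim=6, hidden=16)
features = jax.random.normal(jax.random.PRNGKey(1), (8, 4, 6))
clicks = jnp.zeros((8, 4)).at[:, 1].set(1.0)
grads = grad_fn(params, features, clicks)
```

The naive expression rounds `sigmoid(50.0)` to exactly 1.0 in float32 and returns `-inf`, which would poison every gradient in a batch. The log-space version stays finite.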