CLAX: Fast and Flexible Neural Click Models in JAX

📅 2025-11-05
📈 Citations: 0
Influential: 0
🤖 AI Summary
Problem: Existing click models based on probabilistic graphical models (PGMs) rely heavily on the expectation-maximization (EM) algorithm. This hinders integration with modern deep learning frameworks and makes it difficult to retain interpretability while also training efficiently.
Method: This paper introduces CLAX, a JAX-based neural click modeling library that systematically replaces EM with gradient-based optimization for classic click models, enabling end-to-end, differentiable training. Leveraging JAX's automatic differentiation and XLA compilation, the framework supports flexible integration of embedding layers, deep neural networks, and user-defined modules.
Results: On the full Baidu-ULTR dataset (over one billion sessions), training takes approximately two hours on a single GPU, orders of magnitude faster than EM-based training. The library implements ten classic click models. The core contribution is a numerically stable, gradient-based realization of PGM-based click models that enables efficient, modular, large-scale training while preserving their structural interpretability.
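
As a rough illustration of what replacing EM with gradient-based optimization means, the sketch below fits the position-based model (PBM), one classic click model, by plain gradient descent on its log-likelihood in JAX. It is a minimal sketch under stated assumptions, not CLAX's API: the function names (`pbm_log_likelihood`, `loss_fn`, `sgd_step`, `simulate`), the logit parametrization, the synthetic data, and the hand-written SGD update (instead of an optimizer library) are all hypothetical choices for illustration. Numerical stability comes from staying in log space: `jax.nn.log_sigmoid` for the click term and `jnp.logaddexp` for the no-click term, using 1 - θγ = (1 - θ) + θ(1 - γ).

```python
import jax
import jax.numpy as jnp

def pbm_log_likelihood(params, doc_ids, positions, clicks):
    """Per-observation log-likelihood under the PBM (hypothetical sketch).

    P(click) = theta_k * gamma_d with theta_k = sigmoid(exam_logits[k])
    and gamma_d = sigmoid(attr_logits[d]).
    """
    a = params["exam_logits"][positions]
    b = params["attr_logits"][doc_ids]
    # log P(click) = log theta + log gamma, computed with log_sigmoid for stability.
    log_p_click = jax.nn.log_sigmoid(a) + jax.nn.log_sigmoid(b)
    # No click = "not examined" or "examined but not attractive", summed in log space.
    log_p_no_click = jnp.logaddexp(
        jax.nn.log_sigmoid(-a), jax.nn.log_sigmoid(a) + jax.nn.log_sigmoid(-b)
    )
    return clicks * log_p_click + (1.0 - clicks) * log_p_no_click

def loss_fn(params, doc_ids, positions, clicks):
    # Negative mean log-likelihood; minimizing it takes the place of EM updates.
    return -jnp.mean(pbm_log_likelihood(params, doc_ids, positions, clicks))

@jax.jit
def sgd_step(params, doc_ids, positions, clicks, lr):
    # One plain gradient-descent step over all parameters in the pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, doc_ids, positions, clicks)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

def simulate(key, n_sessions, true_exam, true_attr):
    """Sample synthetic PBM clicks for random rankings (illustration only)."""
    k_docs, k_clicks = jax.random.split(key)
    n_positions, n_docs = true_exam.shape[0], true_attr.shape[0]
    doc_ids = jax.random.randint(k_docs, (n_sessions, n_positions), 0, n_docs)
    positions = jnp.broadcast_to(jnp.arange(n_positions), (n_sessions, n_positions))
    p_click = true_exam[positions] * true_attr[doc_ids]
    clicks = jax.random.bernoulli(k_clicks, p_click).astype(jnp.float32)
    return doc_ids, positions, clicks

# Usage on synthetic data: 5 ranks, 20 documents, 5000 sessions.
true_exam = jnp.array([0.9, 0.7, 0.5, 0.3, 0.2])
true_attr = jax.random.uniform(jax.random.PRNGKey(1), (20,), minval=0.1, maxval=0.9)
doc_ids, positions, clicks = simulate(jax.random.PRNGKey(0), 5000, true_exam, true_attr)

params = {"exam_logits": jnp.zeros(5), "attr_logits": jnp.zeros(20)}
initial_loss = loss_fn(params, doc_ids, positions, clicks)
for _ in range(300):
    params, _ = sgd_step(params, doc_ids, positions, clicks, 1.0)
final_loss = loss_fn(params, doc_ids, positions, clicks)
```

The PBM factorization is only identified up to a scale (c·θ_k and γ_d/c give the same click probabilities), so a fit like this is better judged by its predicted click probabilities than by the raw θ and γ values.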

📝 Abstract
CLAX is a JAX-based library that implements classic click models using modern gradient-based optimization. While neural click models have emerged over the past decade, complex click models based on probabilistic graphical models (PGMs) have not systematically adopted gradient-based optimization, preventing practitioners from leveraging modern deep learning frameworks while preserving the interpretability of classic models. CLAX addresses this gap by replacing EM-based optimization with direct gradient-based optimization in a numerically stable manner. The framework's modular design enables the integration of any component, from embeddings and deep networks to custom modules, into classic click models for end-to-end optimization. We demonstrate CLAX's efficiency by running experiments on the full Baidu-ULTR dataset comprising over a billion user sessions in approximately 2 hours on a single GPU, orders of magnitude faster than traditional EM approaches. CLAX implements ten classic click models, serving both industry practitioners seeking to understand user behavior and improve ranking performance at scale and researchers developing new click models. CLAX is available at: https://github.com/philipphager/clax
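
The cascade model is another classic click model whose likelihood can be written directly in log space, so that `jax.grad` and `jax.vmap` apply without any EM step. The sketch below is a hypothetical illustration, not code from CLAX: it assumes one ranked list per call, attractiveness given as logits in rank order, and at most one click per session (the cascade assumption). Documents above the click contribute log(1 - γ), the clicked document contributes log γ, and documents below the click are treated as unexamined and contribute nothing.

```python
import jax
import jax.numpy as jnp

def cascade_log_likelihood(attr_logits, clicks):
    """Log-likelihood of one ranked list under the cascade model (sketch).

    attr_logits: (n,) attractiveness logits in rank order (hypothetical layout).
    clicks: (n,) integer 0/1 array with at most one click.
    """
    n = clicks.shape[0]
    ranks = jnp.arange(n)
    # Rank of the first click, or n when the session has no click.
    first_click = jnp.where(clicks.sum() > 0, jnp.argmax(clicks), n)
    # log(gamma) and log(1 - gamma) via log_sigmoid, avoiding log(0) and 1 - p cancellation.
    log_attr = jax.nn.log_sigmoid(attr_logits)
    log_not_attr = jax.nn.log_sigmoid(-attr_logits)
    # Above the click: examined but skipped. At the click: attractive. Below: unexamined.
    skipped_term = jnp.where(ranks < first_click, log_not_attr, 0.0)
    clicked_term = jnp.where(ranks == first_click, log_attr, 0.0)
    return jnp.sum(skipped_term + clicked_term)

# Usage: score all n + 1 possible outcomes of a 3-document list in one batched call.
logits = jnp.array([0.0, 1.0, -1.0])
outcomes = jnp.concatenate([jnp.zeros((1, 3), dtype=jnp.int32), jnp.eye(3, dtype=jnp.int32)])
batch_ll = jax.vmap(cascade_log_likelihood, in_axes=(None, 0))(logits, outcomes)
grads = jax.grad(cascade_log_likelihood)(logits, outcomes[2])  # click at rank 1 (0-indexed)
```

Summing the probabilities of the n + 1 possible outcomes (no click, or a single click at each rank) gives 1, which is a quick sanity check for a hand-written click-model likelihood.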
Problem

Research questions and friction points this paper is trying to address.

Replacing EM optimization with gradient-based methods for click models
Integrating modern deep learning components into interpretable classic models
Enabling efficient billion-scale user session analysis on a single GPU
Innovation

Methods, ideas, or system contributions that make the work stand out.

Replaces EM optimization with gradient-based methods
Enables modular integration of neural network components
Achieves billion-scale training in hours on a single GPU
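
To illustrate modular integration of neural network components, the sketch below keeps the PBM's per-rank examination parameters but predicts attractiveness with a small MLP over query-document features, so both parts are trained end to end from clicks. The parameter layout, feature shapes, and function names (`init_params`, `attractiveness_logits`, `neural_pbm_loss`) are hypothetical assumptions for illustration, not CLAX's interface; the MLP could be swapped for an embedding table or any other differentiable module.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_positions, feature_dim, hidden_dim):
    """Hypothetical parameter layout: PBM examination logits plus a 2-layer MLP."""
    k1, k2 = jax.random.split(key)
    return {
        "exam_logits": jnp.zeros(n_positions),
        "w1": jax.random.normal(k1, (feature_dim, hidden_dim)) * jnp.sqrt(2.0 / feature_dim),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim,)) * jnp.sqrt(1.0 / hidden_dim),
        "b2": jnp.zeros(()),
    }

def attractiveness_logits(params, features):
    # Any differentiable module could go here; a small ReLU MLP is one choice.
    hidden = jax.nn.relu(features @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]

def neural_pbm_loss(params, features, positions, clicks):
    """Negative mean PBM log-likelihood with MLP-predicted attractiveness."""
    a = params["exam_logits"][positions]
    b = attractiveness_logits(params, features)
    log_p_click = jax.nn.log_sigmoid(a) + jax.nn.log_sigmoid(b)
    # 1 - theta*gamma = (1 - theta) + theta*(1 - gamma), evaluated in log space.
    log_p_no_click = jnp.logaddexp(
        jax.nn.log_sigmoid(-a), jax.nn.log_sigmoid(a) + jax.nn.log_sigmoid(-b)
    )
    return -jnp.mean(clicks * log_p_click + (1.0 - clicks) * log_p_no_click)

# Usage on random placeholder data: 256 sessions, 5 ranks, 8 features per document.
key_params, key_x, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(key_params, n_positions=5, feature_dim=8, hidden_dim=16)
features = jax.random.normal(key_x, (256, 5, 8))
positions = jnp.broadcast_to(jnp.arange(5), (256, 5))
clicks = jax.random.bernoulli(key_c, 0.2, (256, 5)).astype(jnp.float32)
loss, grads = jax.value_and_grad(neural_pbm_loss)(params, features, positions, clicks)
```

In this design the examination logits remain one interpretable number per rank, while the attractiveness function can generalize to documents never seen during training.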
Philipp Hager
University of Amsterdam
Learning to Rank, Information Retrieval, Counterfactual Learning, Machine Learning
O. Zoeter
Booking.com, Amsterdam, The Netherlands
M. de Rijke
University of Amsterdam, Amsterdam, The Netherlands