🤖 AI Summary
To address the difficulty of directly optimizing the Jeffreys divergence in density estimation, this paper proposes a joint training framework that co-optimizes normalizing flows (NFs) and energy-based models (EBMs). A learnable proxy (surrogate) model is introduced to approximate the reverse KL term, which is hard to compute from data samples alone, so that the Jeffreys divergence can be estimated in a stable, differentiable way. The method casts joint training as a constrained optimization problem that dynamically balances the forward and reverse KL objectives. This avoids the instability of adversarial min-max training while preserving the exact sampling of NFs and the flexible expressivity of EBMs. Experiments on density estimation and image generation are reported to show improved training stability and better fidelity in capturing multimodal structure. The approach also applies to simulation-based inference and other downstream tasks that require robust, gradient-based divergence minimization.
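To make the forward/reverse asymmetry concrete, here is a minimal stdlib-only Python sketch using one-dimensional Gaussians, where both KL directions have a closed form. The target p = N(0, 1), the model q = N(1, 2^2), the helper names (gauss_logpdf, kl_gauss, jeffreys_gauss), and the moment-matched Gaussian used as a stand-in for log p are all hypothetical choices for illustration; the proxy described above is a learned model, not this closed-form fit.

```python
import math
import random

def gauss_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - (x - mu) ** 2 / (2.0 * sigma**2)

def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2))."""
    return math.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2.0 * s2**2) - 0.5

def jeffreys_gauss(mu1, s1, mu2, s2):
    """Jeffreys divergence: forward KL plus reverse KL, symmetric in its arguments."""
    return kl_gauss(mu1, s1, mu2, s2) + kl_gauss(mu2, s2, mu1, s1)

rng = random.Random(0)
n = 200_000

# Hypothetical "data" distribution p = N(0, 1): in practice only samples are observed.
data = [rng.gauss(0.0, 1.0) for _ in range(n)]

# Hypothetical model q = N(1, 2^2): its log-density and sampler are both tractable.
mu_q, s_q = 1.0, 2.0

# Forward-KL direction: the cross-entropy -E_p[log q] needs only data samples
# and log q, which is why maximum likelihood is tractable.
cross_entropy = -sum(gauss_logpdf(x, mu_q, s_q) for x in data) / n

# Reverse-KL direction: E_q[log q - log p] needs log p at model samples, which
# raw data samples do not provide. Illustrative assumption: a Gaussian fitted
# to the data by moment matching stands in for the unknown log p.
mu_hat = sum(data) / n
s_hat = math.sqrt(sum((x - mu_hat) ** 2 for x in data) / n)
model_samples = [rng.gauss(mu_q, s_q) for _ in range(n)]
reverse_kl_est = sum(
    gauss_logpdf(x, mu_q, s_q) - gauss_logpdf(x, mu_hat, s_hat) for x in model_samples
) / n

print(f"exact reverse KL(q||p): {kl_gauss(mu_q, s_q, 0.0, 1.0):.3f}")
print(f"proxy-based estimate  : {reverse_kl_est:.3f}")
print(f"exact Jeffreys J(p,q) : {jeffreys_gauss(0.0, 1.0, mu_q, s_q):.3f}")
```

Here the exact reverse KL is 2 - ln 2 (about 1.307) and the exact Jeffreys divergence is 1.75; with this many samples the proxy-based estimate should land close to 1.307. Note that the cross-entropy term needs no knowledge of log p, whereas the reverse term is only as accurate as the proxy that replaces log p.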
📝 Abstract
Many tasks in machine learning can be described as, or reduced to, learning a probability distribution given a finite set of samples. A common approach is to minimize a statistical divergence between the (empirical) data distribution and a parameterized distribution, e.g., a normalizing flow (NF) or an energy-based model (EBM). In this context, the forward KL divergence is ubiquitous due to its tractability, though its asymmetry may prevent it from capturing some properties of the target distribution. Symmetric alternatives either involve brittle min-max formulations and adversarial training (e.g., generative adversarial networks) or require evaluating the reverse KL divergence, as is the case for the symmetric Jeffreys divergence, which is challenging to compute from samples. This work sets out to develop a new approach to minimizing the Jeffreys divergence. To do so, it uses a proxy model whose goal is not only to fit the data but also to assist in optimizing the Jeffreys divergence of the main model. This joint training task is formulated as a constrained optimization problem, yielding a practical algorithm that adapts the models' priorities throughout training. We illustrate how this framework can be used to combine the advantages of NFs and EBMs in tasks such as density estimation, image generation, and simulation-based inference.
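The abstract does not spell out the constrained problem, so what follows is only a generic sketch of the mechanism it alludes to: a Lagrange multiplier acting as a weight that is adjusted during training instead of being fixed by hand. The one-dimensional toy problem (minimize f(x) = (x - 3)^2 subject to g(x) = x^2 - 1 <= 0), the function primal_dual, and the step sizes are hypothetical choices for illustration; the technique shown is plain gradient descent-ascent on the Lagrangian f(x) + lam * g(x), not the paper's algorithm.

```python
def primal_dual(steps=20_000, lr_primal=0.01, lr_dual=0.01):
    """Gradient descent-ascent on L(x, lam) = f(x) + lam * g(x).

    Hypothetical toy problem, not the paper's formulation:
      f(x) = (x - 3)^2   plays the role of the main objective,
      g(x) = x^2 - 1     is a constraint that must satisfy g(x) <= 0.
    """
    x, lam = 0.0, 0.0
    for _ in range(steps):
        # Primal step: descend on f + lam * g; lam sets how much g matters right now.
        grad_x = 2.0 * (x - 3.0) + lam * 2.0 * x
        x -= lr_primal * grad_x
        # Dual step: ascend on the constraint value g(x), projected onto lam >= 0.
        # lam grows while the constraint is violated and decays toward zero when slack.
        lam = max(0.0, lam + lr_dual * (x * x - 1.0))
    return x, lam

x_star, lam_star = primal_dual()
print(f"x = {x_star:.4f}, lambda = {lam_star:.4f}")
```

By the KKT conditions the solution is x = 1 with multiplier lambda = 2 (from 2(x - 3) + 2 * lambda * x = 0 at x = 1), and the iterates approach it. The multiplier starts at zero, stays there while the constraint is slack, and rises only once x overshoots the feasible region, which is the sense in which the balance between the two terms is set by training rather than by a fixed hyperparameter.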