🤖 AI Summary
This work addresses the discretization gap, i.e., the mismatch between training and inference, in logic gate networks, which typically employ soft mixing during training and hard selection at inference time. The authors systematically analyze behavioral differences across four training strategies and propose CAGE (Confidence-Adaptive Gradient Estimation), a method that decouples forward computation from stochasticity to maintain alignment in the forward pass while adaptively adjusting gradient estimation. By integrating Hard-ST with CAGE, the proposed framework achieves zero selection gap across all temperature settings without suffering from accuracy collapse. Empirical results demonstrate that the approach attains over 98% accuracy on MNIST and exceeds 58% on CIFAR-10, substantially outperforming baseline methods such as Gumbel-ST.
📄 Abstract
In neural network models, soft mixtures of fixed candidate components (e.g., logic gates and sub-networks) are often used during training for stable optimization, while hard selection is typically used at inference. This discrepancy raises the question of how well training-time behavior transfers to inference. We analyze this gap by separating forward-pass computation (hard selection vs. soft mixture) from stochasticity (with vs. without Gumbel noise). Using logic gate networks as a testbed, we observe distinct behaviors across four methods: Hard-ST achieves zero selection gap by construction; Gumbel-ST achieves near-zero gap when training succeeds but suffers accuracy collapse at low temperatures; Soft-Mix achieves small gap only at low temperature via weight concentration; and Soft-Gumbel exhibits large gaps despite Gumbel noise, confirming that noise alone does not reduce the gap. We propose CAGE (Confidence-Adaptive Gradient Estimation) to maintain gradient flow while preserving forward alignment. On logic gate networks, Hard-ST with CAGE achieves over 98% accuracy on MNIST and over 58% on CIFAR-10, both with zero selection gap across all temperatures, while Gumbel-ST without CAGE suffers a 47-point accuracy collapse.
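To make the selection-gap notion concrete, here is a minimal NumPy sketch contrasting the training-time soft mixture with the inference-time hard selection for a single learnable gate. The three-gate candidate set, the logits, and the temperature are illustrative assumptions, not the paper's actual configuration; the point is only that a hard forward pass (as in Hard-ST) matches inference by construction, while a soft mixture generally does not.

```python
import numpy as np

def softmax(z, tau):
    # temperature-scaled softmax over gate-selection logits
    e = np.exp((z - z.max()) / tau)
    return e / e.sum()

# Hypothetical candidate set: a subset of the 16 two-input logic gates
GATES = [
    lambda a, b: a & b,   # AND
    lambda a, b: a | b,   # OR
    lambda a, b: a ^ b,   # XOR
]

def soft_forward(z, a, b, tau):
    # training-time soft mixture: convex combination of all candidate gates
    w = softmax(z, tau)
    return sum(wi * g(a, b) for wi, g in zip(w, GATES))

def hard_forward(z, a, b):
    # inference-time hard selection: evaluate only the argmax gate
    return GATES[int(np.argmax(z))](a, b)

z = np.array([2.0, 0.5, 0.1])  # illustrative learned logits
a, b = 1, 0

# Selection gap: |training-time output - inference-time output| on one input.
gap_soft = abs(soft_forward(z, a, b, tau=1.0) - hard_forward(z, a, b))
gap_hard = abs(hard_forward(z, a, b) - hard_forward(z, a, b))  # Hard-ST forward
```

Here `gap_hard` is exactly zero because Hard-ST uses the same hard forward pass during training and inference, whereas `gap_soft` is nonzero whenever the softmax weights have not fully concentrated on one gate (it shrinks as the temperature is lowered, matching the Soft-Mix behavior described above).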