Revisiting Nearest Neighbor for Tabular Data: A Deep Tabular Baseline Two Decades Later

📅 2024-07-03
📈 Citations: 5
Influential: 1
🤖 AI Summary
This work addresses a fundamental limitation of classical k-nearest neighbors (k-NN) in tabular data modeling: its non-differentiability, which precludes end-to-end optimization. The authors revisit Neighbourhood Components Analysis (NCA), a differentiable relaxation of k-NN, and build a deep tabular framework on top of it. Methodologically, NCA is integrated directly into deep neural representation learning, eliminating hand-crafted feature engineering and conventional dimensionality reduction; the framework incorporates stochastic augmentation, an adaptive loss function, and ensemble-based prediction for fully end-to-end training. Contributions: (i) the paper empirically establishes that a pure NCA-based model already achieves competitive performance on standard tabular benchmarks; (ii) augmented with deep representations and training stochasticity, the method attains state-of-the-art results across 300 classification and regression benchmarks, significantly outperforming leading deep tabular models (e.g., TabNet, SAINT) and matching advanced gradient-boosted tree methods (e.g., CatBoost). The code is publicly available.

📝 Abstract
The widespread enthusiasm for deep learning has recently expanded into the domain of tabular data. Recognizing that the advancement in deep tabular methods is often inspired by classical methods, e.g., integration of nearest neighbors into neural networks, we investigate whether these classical methods can be revitalized with modern techniques. We revisit a differentiable version of $K$-nearest neighbors (KNN) -- Neighbourhood Components Analysis (NCA) -- originally designed to learn a linear projection to capture semantic similarities between instances, and seek to gradually add modern deep learning techniques on top. Surprisingly, our implementation of NCA using SGD and without dimensionality reduction already achieves decent performance on tabular data, in contrast to the results of using existing toolboxes like scikit-learn. Further equipping NCA with deep representations and additional training stochasticity significantly enhances its capability, being on par with the leading tree-based method CatBoost and outperforming existing deep tabular models in both classification and regression tasks on 300 datasets. We conclude our paper by analyzing the factors behind these improvements, including loss functions, prediction strategies, and deep architectures. The code is available at https://github.com/qile2000/LAMDA-TALENT.
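As the abstract notes, NCA makes nearest-neighbor classification differentiable by replacing the hard neighbor rule with a softmax over negative squared distances in the learned space. The following is a minimal numpy sketch of the standard leave-one-out log-NCA objective; the function name and the log-loss variant are illustrative assumptions, not the paper's exact implementation:

```python
import numpy as np

def nca_loss(embeddings, labels):
    """Leave-one-out NCA objective: for each point i, the probability of
    picking a same-class neighbor j under a softmax over negative squared
    distances in the embedded space (self excluded)."""
    # Pairwise squared Euclidean distances.
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dist = (diff ** 2).sum(-1)
    # Softmax over negative distances; exclude each point itself.
    logits = -sq_dist
    np.fill_diagonal(logits, -np.inf)
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    # Probability mass assigned to same-class neighbors.
    same = labels[:, None] == labels[None, :]
    np.fill_diagonal(same, False)
    p_correct = (p * same).sum(axis=1)
    # Negative log-likelihood over the batch.
    return -np.log(np.clip(p_correct, 1e-12, None)).mean()
```

Because every term is smooth in the embeddings, this objective can be minimized with SGD over the output of a neural network, which is the setting the abstract describes.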
Problem

Research questions and friction points this paper is trying to address.

Revitalizing classical tabular data methods with modern deep learning techniques.
Enhancing Neighbourhood Components Analysis (NCA) using deep representations and stochasticity.
Comparing NCA performance with leading tree-based and deep tabular models.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Differentiable KNN with modern deep learning
Enhanced NCA using deep representations
On par with CatBoost; outperforms existing deep tabular models
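At inference time, a differentiable KNN of this kind classifies a query by a softmax-weighted vote over training points in the learned embedding space. A hedged sketch of such a prediction rule follows; the function name, temperature parameter, and single-query interface are illustrative assumptions (the paper additionally ensembles predictions):

```python
import numpy as np

def soft_knn_predict(query, train_emb, train_labels, n_classes, temp=1.0):
    """Class probabilities for one query: softmax weights over negative
    squared distances to all training embeddings, summed per class."""
    sq_dist = ((train_emb - query) ** 2).sum(axis=1)
    # Shift by the minimum distance for numerical stability.
    w = np.exp(-(sq_dist - sq_dist.min()) / temp)
    w /= w.sum()
    probs = np.zeros(n_classes)
    for c in range(n_classes):
        probs[c] = w[train_labels == c].sum()
    return probs
```

Unlike hard KNN, the prediction varies smoothly with the embeddings, so the same rule can serve as both the training-time and test-time classifier.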
👥 Authors

Han-Jia Ye
Nanjing University
Machine Learning, Data Mining, Metric Learning, Meta-Learning

Huai-Hong Yin
School of Artificial Intelligence, Nanjing University, China; National Key Laboratory for Novel Software Technology, Nanjing University

De-Chuan Zhan
Nanjing University, China
Machine Learning, Data Mining

Wei-Lun Chao