🤖 AI Summary
This work addresses the data deletion problem in deep learning: efficiently predicting how a model's behavior changes after a subset of its training data is removed. The authors propose an approach grounded in an assumption they call stability, sketching arithmetic circuits locally via higher-order derivatives in random complex directions and using forward-mode automatic differentiation to compute these derivatives cheaply. Their method incurs only a poly(1/ε) overhead over standard training during a one-time precomputation phase, requires storage equivalent to poly(1/ε) model copies, and achieves prediction latency only poly(1/ε) times slower than standard inference. Crucially, the prediction error vanishes as ε approaches zero, and the stability assumption appears compatible with learning state-of-the-art, high-performance AI models.
📝 Abstract
How does the choice of training data influence an AI model? This question is of central importance to interpretability, privacy, and basic science. At its core is the data deletion problem: after a reasonable amount of precomputation, quickly predict how the model would behave in a given situation if a given subset of training data had been excluded from the learning algorithm.
We present a data deletion scheme capable of predicting model outputs with arbitrarily small error $\varepsilon$ in the deep learning setting. Our precomputation and prediction algorithms are only $\mathrm{poly}(1/\varepsilon)$ factors slower than regular training and inference, respectively, and the storage requirement is that of $\mathrm{poly}(1/\varepsilon)$ models.
Our proof is based on an assumption that we call "stability." In contrast to the assumptions made by prior work, stability appears to be fully compatible with learning powerful AI models. In support of this, we show that stability is satisfied in a minimal set of experiments with microgpt. Our code is available at https://github.com/SamSpo1/microgpt-sketch.
At a technical level, our work is based on a new method for locally sketching an arithmetic circuit by computing higher-order derivatives in random complex directions. Forward-mode automatic differentiation allows cheap computation of these derivatives.
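To make this concrete, the sketch below is a minimal, illustrative toy (not the paper's implementation; the names `Jet` and `directional_derivatives` are hypothetical) of Taylor-mode forward differentiation over an arithmetic circuit built from + and ×. Seeding each input with a complex direction propagates a truncated Taylor series through the circuit, so one forward pass yields all directional derivatives up to a chosen order.

```python
class Jet:
    """Truncated Taylor series in t: represents f(x + t*v) = sum_k c[k] * t^k."""

    def __init__(self, coeffs, order):
        # Pad coefficients with zeros up to the truncation order.
        self.c = list(coeffs) + [0j] * (order + 1 - len(coeffs))
        self.K = order

    def _lift(self, other):
        # Treat plain scalars as constant jets.
        return other if isinstance(other, Jet) else Jet([other], self.K)

    def __add__(self, other):
        other = self._lift(other)
        return Jet([a + b for a, b in zip(self.c, other.c)], self.K)

    __radd__ = __add__

    def __mul__(self, other):
        # Truncated Cauchy product of the two series.
        other = self._lift(other)
        out = [0j] * (self.K + 1)
        for i, a in enumerate(self.c):
            for j in range(self.K + 1 - i):
                out[i + j] += a * other.c[j]
        return Jet(out, self.K)

    __rmul__ = __mul__


def directional_derivatives(f, x, v, order):
    """k-th derivative of f along direction v, for k = 0..order, in one pass."""
    jets = [Jet([xi, vi], order) for xi, vi in zip(x, v)]
    out = f(*jets)
    derivs, fact = [], 1
    for k, ck in enumerate(out.c):
        if k > 0:
            fact *= k          # Taylor coefficient c_k = D_v^k f / k!
        derivs.append(ck * fact)
    return derivs


# Example: f(x, y) = x*y + x^2 at (2, 3), complex direction v = (1, i).
f = lambda x, y: x * y + x * x
d0, d1, d2 = directional_derivatives(f, [2.0, 3.0], [1 + 0j, 1j], order=2)
# d0 = f(x) = 10;  d1 = D_v f = 7 + 2i;  d2 = D_v^2 f = 2 + 2i
```

In a real system the direction v would be drawn at random (as the paper's random complex directions suggest), and the circuit would be the model's forward computation rather than a two-variable polynomial; the cost per pass grows only polynomially in the truncation order.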