Demystifying Language Model Forgetting with Low-rank Example Associations

📅 2024-06-20
📈 Citations: 2
Influential: 0
🤖 AI Summary
This work addresses forgetting of upstream data during large language model (LLM) fine-tuning. The authors empirically show that forgetting patterns exhibit low-rank structure across tasks and upstream examples, indicating a simple, learnable dependency between newly learned tasks and the examples they cause to be forgotten. First, they provide the first empirical validation of the intrinsic low-rank nature of the forgetting matrix. Leveraging this insight, they propose a semantics-agnostic matrix-completion approach to forgetting prediction that bypasses costly LM-based semantic modeling, and identify highly forgotten examples significantly more accurately than state-of-the-art LM-based methods. Further, they design a replay-weighted fine-tuning strategy that selectively re-trains on examples predicted to be highly forgotten. Experiments demonstrate that the method substantially reduces observed forgetting (p < 0.01), offering a practical paradigm for controllable fine-tuning and continual learning.

📝 Abstract
Large Language Models (LLMs) suffer from forgetting of upstream data when fine-tuned. Despite efforts on mitigating forgetting, few have investigated whether, and how, forgotten upstream examples depend on newly learned tasks. Insights into such dependencies enable efficient and targeted mitigation of forgetting. In this paper, we empirically analyze forgetting that occurs in $N$ upstream examples of language modeling or instruction-tuning after fine-tuning LLMs on one of $M$ new tasks, visualized in $M \times N$ matrices. We show that the matrices are often well-approximated with low-rank matrices, indicating the dominance of simple associations between the learned tasks and forgotten upstream examples. Leveraging the analysis, we predict forgetting of upstream examples when fine-tuning on unseen tasks with matrix completion over the empirical associations. This enables fast identification of the most forgotten examples without expensive inference on the entire upstream data. The approach, despite its simplicity, outperforms prior approaches that learn semantic relationships of learned tasks and upstream examples with LMs for predicting forgetting. We demonstrate the practical utility of our analysis by showing statistically significantly reduced forgetting as we upweight predicted examples for replay at fine-tuning. Project page: https://inklab.usc.edu/lm-forgetting-prediction/
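One simple way to realize matrix-completion-style prediction for an unseen task is sketched below. This is an illustrative baseline on synthetic data, not necessarily the paper's exact algorithm: learn the example-axis subspace from past tasks' forgetting rows, then fit the unseen task's full row from only a handful of probed entries.

```python
import numpy as np

# Assumed synthetic setup: past tasks' forgetting rows share a rank-k
# "association" structure with an unseen task's row.
rng = np.random.default_rng(1)
M, N, k = 40, 200, 3
B = rng.normal(size=(k, N))                  # shared example-side factors
past = rng.normal(size=(M, k)) @ B           # observed forgetting, M past tasks
new_true = (rng.normal(size=(1, k)) @ B)[0]  # unseen task's true forgetting

# Fit the example-axis subspace from past tasks, then least-squares fit the
# new task's coefficients using only a small set of probed upstream examples.
_, _, Vt = np.linalg.svd(past, full_matrices=False)
V_k = Vt[:k].T                               # (N, k) basis over examples
probe = rng.choice(N, size=20, replace=False)
coef, *_ = np.linalg.lstsq(V_k[probe], new_true[probe], rcond=None)
new_pred = V_k @ coef                        # predicted forgetting for all N

rel_err = np.linalg.norm(new_pred - new_true) / np.linalg.norm(new_true)
print(f"relative prediction error from 20/200 probes: {rel_err:.3f}")
```

The point of the sketch: when associations are low-rank, measuring forgetting on a few probe examples suffices to rank all $N$ upstream examples, avoiding inference over the entire upstream set.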
Problem

Research questions and friction points this paper is trying to address.

Analyze forgetting in large language models
Predict forgetting with low-rank matrix completion
Mitigate forgetting by upweighting predicted examples
Innovation

Methods, ideas, or system contributions that make the work stand out.

Low-rank matrix approximation
Matrix completion prediction
Upweighted example replay
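The replay idea above can be sketched as weighted sampling of a replay buffer. All names and the softmax weighting below are illustrative assumptions, not the paper's exact scheme:

```python
import numpy as np

# Hypothetical replay step: sample upstream examples for replay with
# probability increasing in their predicted forgetting score.
rng = np.random.default_rng(2)
N, buffer_size = 200, 32
predicted_forgetting = rng.gamma(shape=2.0, size=N)  # stand-in predictions

# Temperature-controlled softmax turns scores into sampling weights;
# a lower temperature concentrates replay on the most-forgotten examples.
temperature = 0.5
w = np.exp(predicted_forgetting / temperature)
w /= w.sum()
replay_idx = rng.choice(N, size=buffer_size, replace=False, p=w)

mean_sel = predicted_forgetting[replay_idx].mean()
print(f"mean predicted forgetting: all={predicted_forgetting.mean():.2f}, "
      f"replayed={mean_sel:.2f}")
```

The sampled indices would then be mixed into fine-tuning batches; the selected examples have a much higher average predicted-forgetting score than the upstream set as a whole.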