Matryoshka Model Learning for Improved Elastic Student Models

📅 2025-05-29
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address the high retraining and deployment costs of ML models in industrial settings where serving constraints change dynamically, this paper proposes MatTA, a three-tier Teacher-TA-Student knowledge distillation framework. MatTA introduces an intermediate TA model that inherits both the Teacher's knowledge capacity and the Student's architectural compatibility, enabling multi-granularity, elastic Student model extraction. Through parameter sharing, progressive pruning, and multi-objective loss optimization, MatTA yields multiple plug-and-play Student models with tunable accuracy-cost trade-offs from a single training run. In production A/B tests, MatTA improves a key system metric by 20%. On GPT-2 Medium, it achieves relative accuracy gains of 24.3% on SAT Math and 10.1% on LAMBADA, demonstrating its effectiveness in adapting large language models to dynamic operational constraints.

📝 Abstract
Industry-grade ML models are carefully designed to meet rapidly evolving serving constraints, which requires significant resources for model development. In this paper, we propose MatTA, a framework for training multiple accurate Student models using a novel Teacher-TA-Student recipe. TA models are larger versions of the Student models with higher capacity, and thus allow Student models to better relate to the Teacher model and also bring in more domain-specific expertise. Furthermore, multiple accurate Student models can be extracted from the TA model. Therefore, despite only one training run, our methodology provides multiple servable options to trade off accuracy for lower serving cost. We demonstrate the proposed method, MatTA, on proprietary datasets and models. Its practical efficacy is underscored by live A/B tests within a production ML system, demonstrating 20% improvement on a key metric. We also demonstrate our method on GPT-2 Medium, a public model, and achieve relative improvements of over 24% on SAT Math and over 10% on the LAMBADA benchmark.
Problem

Research questions and friction points this paper is trying to address.

Efficiently training multiple accurate Student models under constraints
Reducing serving costs while maintaining model accuracy
Improving domain-specific expertise transfer from Teacher to Student models
Innovation

Methods, ideas, or system contributions that make the work stand out.

Teacher-TA-Student recipe for model training
Extract multiple Student models from TA
Single training run enables servable options
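The Innovation points above can be sketched in code: a TA network whose leading hidden units form nested, Matryoshka-style Students, trained against Teacher outputs so that any prefix width is a directly servable model. This is a minimal illustrative sketch, not the paper's actual recipe; the dimensions, the nested widths, and the plain KL distillation loss are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical dimensions, not taken from the paper.
d_in, d_hidden, d_out = 8, 16, 4
nested_widths = [4, 8, 16]   # Matryoshka-style nested Student widths

# TA model: one hidden layer whose leading columns form each nested Student.
W1 = rng.normal(size=(d_in, d_hidden)) * 0.1
W2 = rng.normal(size=(d_hidden, d_out)) * 0.1

def ta_logits(x, width):
    """Forward pass using only the first `width` hidden units of the TA."""
    h = np.maximum(x @ W1[:, :width], 0.0)   # ReLU on the sliced hidden layer
    return h @ W2[:width, :]

def distill_loss(x, teacher_logits, widths):
    """Sum of KL(teacher || student) over all nested widths, so one
    training run optimizes every extractable Student at once."""
    p_t = softmax(teacher_logits)
    total = 0.0
    for w in widths:
        p_s = softmax(ta_logits(x, w))
        total += np.mean(np.sum(p_t * (np.log(p_t) - np.log(p_s)), axis=-1))
    return total

def extract_student(width):
    """A servable Student is just a slice of the TA's weights."""
    return W1[:, :width].copy(), W2[:width, :].copy()

x = rng.normal(size=(32, d_in))
teacher = rng.normal(size=(32, d_out))   # stand-in for real Teacher outputs
loss = distill_loss(x, teacher, nested_widths)
W1_s, W2_s = extract_student(8)
```

Because each smaller width is a prefix of the next, deployment can swap Students at serving time by changing a single slice bound, trading accuracy for serving cost without retraining.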
Chetan Verma
@google, @twitter, @ucsd, @iitm
Machine Learning · Language Models · Recommender Systems
A. Timmaraju, Google DeepMind
Cho-Jui Hsieh, Google
Suyash Damle, Google
Ngot Bui, Google
Yang Zhang, Google
Wen Chen, Google
Xin Liu, Google
Prateek Jain, Google DeepMind
Inderjit S. Dhillon
VP, Google Fellow at Google & Professor of Computer Science at UT Austin, Ex-VP at Amazon
Machine Learning · Deep Learning · Large Language Models · Numerical Linear Algebra · Computational