🤖 AI Summary
Robot manipulation policies trained from demonstrations often generalize poorly to execution variations not seen during training. This is in part because standard attention mechanisms process past states without explicitly modeling the temporal structure that demonstrations can contain, such as failure-and-recovery patterns. To address this, we propose State Transition Attention (STA), an attention mechanism that modulates standard attention weights based on learned state evolution patterns. STA is integrated into a Cross-State Transition Attention Transformer. During training it is combined with temporal masking, which randomly removes visual information from recent timesteps so that the policy learns to reason over historical context. We benchmark our approach against visual dynamic masking, Temporal Convolutional Networks (TCNs), and Long Short-Term Memory (LSTM) networks. In simulation, STA consistently outperforms standard cross-attention as well as the TCN and LSTM baselines across all tasks. On precision-critical tasks it achieves more than a two-fold improvement over cross-attention. These results indicate that explicitly modeling temporal structure can improve policy robustness and generalization.
📝 Abstract
Learning robotic manipulation policies through supervised learning from demonstrations remains challenging when policies encounter execution variations not explicitly covered during training. Incorporating historical context through attention mechanisms can improve robustness. However, standard approaches process all past states in a sequence without explicitly modeling the temporal structure that demonstrations may include, such as failure and recovery patterns. We propose a Cross-State Transition Attention Transformer that employs a novel State Transition Attention (STA) mechanism to modulate standard attention weights based on learned state evolution patterns, enabling policies to better adapt their behavior based on execution history. Our approach combines this structured attention with temporal masking during training, where visual information is randomly removed from recent timesteps to encourage temporal reasoning from historical context. Evaluation in simulation shows that STA consistently outperforms standard cross-attention and temporal modeling approaches such as temporal convolutional networks (TCNs) and long short-term memory (LSTM) networks across all tasks, achieving more than a 2x improvement over cross-attention on precision-critical tasks.
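The exact STA formulation is not given in this abstract. As an illustration only, the sketch below shows one way attention weights *could* be modulated by state transitions. A current-state query attends over a history of states with ordinary scaled dot-product scores. Each resulting weight is then multiplied by a learned gate computed from the state change `h_i - h_{i-1}`, and the weights are renormalized. The gate form `sigmoid(w . delta + b)`, along with the names `state_transition_attention` and `w_transition`, is a hypothetical choice made for this example. It is not the authors' method.

```python
import math
from typing import List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def state_transition_attention(
    query: Sequence[float],
    history: Sequence[Sequence[float]],
    w_transition: Sequence[float],
    bias: float = 0.0,
) -> Tuple[List[float], List[float]]:
    """HYPOTHETICAL transition-gated cross-attention over past states.

    1. Standard scaled dot-product scores between the query and each past state.
    2. A gate per timestep from the state change h_i - h_{i-1}
       (zero change for the first step), g_i = sigmoid(w . delta_i + b).
    3. Softmax weights are multiplied by the gates and renormalized.
    Values are taken to be the history vectors themselves for brevity.
    """
    d = len(query)
    base = softmax([dot(query, h) / math.sqrt(d) for h in history])

    gates = []
    for i, h in enumerate(history):
        prev = history[i - 1] if i > 0 else h
        delta = [a - b for a, b in zip(h, prev)]
        gates.append(sigmoid(dot(w_transition, delta) + bias))

    modulated = [a * g for a, g in zip(base, gates)]  # all terms are > 0
    total = sum(modulated)
    weights = [m / total for m in modulated]

    # Context vector: attention-weighted sum of the history states.
    context = [sum(w * h[j] for w, h in zip(weights, history)) for j in range(d)]
    return weights, context


if __name__ == "__main__":
    # A state that changes abruptly at step 2 (e.g., a slip followed by recovery).
    history = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
    weights, _ = state_transition_attention([1.0, 0.0], history, [4.0, 0.0])
    print([round(w, 3) for w in weights])
```

In this toy run, steps 2 and 3 have identical query-key scores. The gate still gives step 2, where the transition happens, more weight than step 3. With `w_transition` set to zeros, every gate equals 0.5 and the weights reduce to ordinary softmax attention. This is one simple way for a transition-aware term to reshape, rather than replace, standard attention.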
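Temporal masking, as described above, removes visual information from recent timesteps during training. The following is a minimal sketch of that idea under assumed details the abstract does not specify. With probability `mask_prob`, it picks a random number of most-recent steps (up to `max_masked_steps`) and zeroes their visual features while keeping proprioceptive features. The `Observation` type, the parameter names, and the choice to zero features rather than drop them are all hypothetical.

```python
import random
from dataclasses import dataclass, replace
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Observation:
    visual: Tuple[float, ...]   # hypothetical image embedding
    proprio: Tuple[float, ...]  # hypothetical robot state (e.g., joint positions)


def temporal_visual_mask(
    history: Sequence[Observation],
    mask_prob: float,
    max_masked_steps: int,
    rng: random.Random,
) -> Tuple[List[Observation], List[bool]]:
    """HYPOTHETICAL training-time augmentation: blank the visual features of
    the k most recent timesteps, with k drawn uniformly from
    [1, min(max_masked_steps, len(history))]. Returns the (possibly) masked
    history and a per-step flag marking which steps were masked."""
    if not history or max_masked_steps < 1 or rng.random() >= mask_prob:
        return list(history), [False] * len(history)

    k = rng.randint(1, min(max_masked_steps, len(history)))
    cutoff = len(history) - k

    masked: List[Observation] = []
    flags: List[bool] = []
    for t, obs in enumerate(history):
        if t >= cutoff:
            # Remove visual input; the policy must rely on earlier context
            # and on the still-available proprioceptive signal.
            masked.append(replace(obs, visual=tuple(0.0 for _ in obs.visual)))
            flags.append(True)
        else:
            masked.append(obs)
            flags.append(False)
    return masked, flags


if __name__ == "__main__":
    rng = random.Random(0)
    hist = [Observation(visual=(1.0, 2.0), proprio=(float(t),)) for t in range(5)]
    _, flags = temporal_visual_mask(hist, mask_prob=1.0, max_masked_steps=3, rng=rng)
    print(flags)
```

Returning the mask flags alongside the history is a design choice in this sketch. A model could use them, for example, to tell a zeroed frame apart from a genuinely dark image. Masking would be applied only during training and disabled at inference.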