🤖 AI Summary
World models trained end-to-end on sparse observations (e.g., LiDAR) may fail to encode the underlying world state, which can hurt generalization and lead to distribution drift. To address this, the paper studies a deep supervision technique: a linear probe is attached to the hidden state of a network trained to predict the next observation, and the probe's error on a subset of ground-truth world features is added to the loss. This term encourages the hidden state to encode those features in a linearly decodable form. In a Flappy Bird environment where the agent observes only LiDAR measurements, the technique improves training and test performance and training stability. It also makes world features easier to decode, including features not used in the probe loss, and reduces distribution drift, especially while flying between successive pipes. The gain from the probe loss is roughly comparable to doubling the model size, which suggests the approach is most useful in compute-limited settings or with smaller models.
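To make the training objective concrete, here is a minimal NumPy sketch of a probe-augmented loss, assuming squared error for both the next-observation prediction and the probe. The function names, the weighting hyperparameter `lam`, and the array shapes are illustrative assumptions, not details taken from the paper. In practice the predictor and the probe would be trained jointly with an autodiff framework; the explicit gradient here only shows how the probe term reaches the hidden state.

```python
import numpy as np


def probe_augmented_loss(pred_next_obs, true_next_obs, hidden,
                         probe_W, probe_b, world_feats, lam=1.0):
    """Next-observation MSE plus a linear-probe MSE on the hidden state.

    Shapes (assumed): pred/true_next_obs (batch, obs_dim), hidden
    (batch, hidden_dim), probe_W (hidden_dim, n_feats), probe_b (n_feats,),
    world_feats (batch, n_feats).
    """
    # Next-observation term: squared error on the predicted LiDAR readings.
    obs_loss = np.mean((pred_next_obs - true_next_obs) ** 2)
    # Linear probe: decode the supervised world features from the hidden state.
    probe_pred = hidden @ probe_W + probe_b
    probe_loss = np.mean((probe_pred - world_feats) ** 2)
    # Total objective; lam is a hypothetical weighting hyperparameter.
    return obs_loss + lam * probe_loss, obs_loss, probe_loss


def probe_grad_wrt_hidden(hidden, probe_W, probe_b, world_feats, lam=1.0):
    """Gradient of lam * probe_loss with respect to the hidden state.

    This is the extra training signal the probe sends into the network.
    """
    residual = hidden @ probe_W + probe_b - world_feats
    # d/dh of mean(residual**2) = 2 * residual @ W.T / (number of elements)
    return lam * 2.0 * residual @ probe_W.T / residual.size


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    hidden = rng.normal(size=(8, 16))          # batch of 8 hidden states
    probe_W = 0.1 * rng.normal(size=(16, 3))   # probe for 3 world features
    probe_b = np.zeros(3)
    feats = rng.normal(size=(8, 3))            # hypothetical ground-truth features
    pred = rng.normal(size=(8, 32))            # hypothetical 32-ray LiDAR prediction
    true = rng.normal(size=(8, 32))
    total, obs_loss, probe_loss = probe_augmented_loss(
        pred, true, hidden, probe_W, probe_b, feats, lam=0.5)
    grad_h = probe_grad_wrt_hidden(hidden, probe_W, probe_b, feats, lam=0.5)
    print(grad_h.shape)  # (8, 16)
```

Because the probe is linear, it cannot compensate for features the hidden state does not linearly encode, so lowering its loss pushes the hidden state itself to carry those features.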
📝 Abstract
Developing effective world models is crucial for creating artificial agents that can reason about and navigate complex environments. In this paper, we investigate a deep supervision technique for encouraging the development of a world model in a network trained end-to-end to predict the next observation. While deep supervision has been widely applied for task-specific learning, our focus is on improving world models. Using an experimental environment based on the Flappy Bird game, where the agent receives only LiDAR measurements as observations, we explore the effect of adding a linear probe component to the network's loss function. This additional term encourages the network to encode a subset of the true underlying world features into its hidden state. Our experiments demonstrate that this supervision technique improves both training and test performance, enhances training stability, and results in more easily decodable world features, even for world features that were not included in training. Furthermore, we observe reduced distribution drift in networks trained with the linear probe, particularly during high-variability phases of the game (flying between successive pipe encounters). The benefit of including the world-feature loss component roughly corresponded to that of doubling the model size, suggesting that the linear probe technique is particularly beneficial in compute-limited settings or when aiming to achieve the best performance with smaller models. These findings contribute to our understanding of how to develop more robust and sophisticated world models in artificial agents, paving the way for further advancements in this field.
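One common way to check how "easily decodable" world features are, including features left out of the training loss, is to fit a fresh linear probe on frozen hidden states and score it on held-out time steps. The sketch below assumes hidden states and ground-truth features have already been collected into arrays; the function names are hypothetical, and least-squares fitting with per-feature R² is one standard choice, not necessarily the paper's evaluation protocol.

```python
import numpy as np


def fit_linear_probe(hidden, feats):
    """Fit feats ~ hidden @ W + b by ordinary least squares."""
    # Append a column of ones so the last coefficient row acts as the bias.
    X = np.hstack([hidden, np.ones((hidden.shape[0], 1))])
    coef, *_ = np.linalg.lstsq(X, feats, rcond=None)
    return coef[:-1], coef[-1]


def probe_r2(hidden, feats, W, b):
    """Per-feature R^2 of the probe's predictions (1.0 = perfectly decodable)."""
    pred = hidden @ W + b
    ss_res = np.sum((feats - pred) ** 2, axis=0)
    ss_tot = np.sum((feats - feats.mean(axis=0)) ** 2, axis=0)
    return 1.0 - ss_res / ss_tot


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical stand-ins: hidden states from a trained predictor and
    # ground-truth world features logged at the same time steps.
    hidden = rng.normal(size=(400, 16))
    mix = np.array([[1.0, 0.5], [-0.3, 2.0]])
    feats = hidden[:, :2] @ mix + 0.1 * rng.normal(size=(400, 2))
    train, test = slice(0, 300), slice(300, 400)
    W, b = fit_linear_probe(hidden[train], feats[train])
    # Per-feature R^2 on held-out steps.
    print(probe_r2(hidden[test], feats[test], W, b))
```

Scoring on held-out steps, rather than on the data used to fit the probe, keeps a high-capacity hidden state from looking decodable merely because the probe overfits.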