🤖 AI Summary
Modern deep learning training pipelines involve heterogeneous components—such as multi-task heads, distillation objectives, or multimodal encoders—yet the field lacks a unified formalism for modeling inter-component dependencies and enabling holistic optimization.
Method: We propose *Learning Diagrams*, a categorical framework for declaratively specifying training workflows as composable, compiler-ready graphical structures. The framework enables constraint-driven component composition and automatic synthesis of joint loss functions that enforce predictive consistency across submodels.
Contribution/Results: Learning diagrams unify diverse paradigms—including few-shot multi-task learning, knowledge distillation, and multimodal learning—under a single abstraction, supporting dynamic in-training and post-hoc reconfiguration. Integrated with PyTorch and Flux.jl, the open-source implementation includes a graph compiler, and case studies on canonical benchmarks demonstrate cross-paradigm expressivity and the efficacy of compositional modeling, yielding more systematic, modular, and interpretable construction of complex models.
📝 Abstract
Motivated by deep learning regimes with multiple interacting yet distinct model components, we introduce learning diagrams, graphical depictions of training setups that capture parameterized learning as data rather than code. A learning diagram compiles to a unique loss function on which component models are trained. The result of training on this loss is a collection of models whose predictions "agree" with one another. We show that a number of popular learning setups such as few-shot multi-task learning, knowledge distillation, and multi-modal learning can be depicted as learning diagrams. We further implement learning diagrams in a library that allows users to build diagrams of PyTorch and Flux.jl models. By implementing some classic machine learning use cases, we demonstrate how learning diagrams allow practitioners to build complicated models as compositions of smaller components, identify relationships between workflows, and manipulate models during or after training. Leveraging a category-theoretic framework, we introduce a rigorous semantics for learning diagrams that puts such operations on a firm mathematical foundation.
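The core idea—a diagram is data (nodes are component models, edges are agreement constraints) that compiles to a single joint loss—can be sketched in a few lines. This is a minimal illustrative mock-up, not the library's actual API: the names `LearningDiagram`, `compile_loss`, and the toy callable models are all hypothetical, and the squared-disagreement penalty stands in for whatever consistency term a real diagram edge would compile to.

```python
from dataclasses import dataclass

@dataclass
class LearningDiagram:
    # Nodes: named component models (plain callables stand in for
    # PyTorch/Flux.jl submodels in this sketch).
    models: dict
    # Edges: (source, target) pairs whose predictions should "agree".
    agreements: list

    def compile_loss(self):
        """Compile the diagram (data) into one joint loss (code)."""
        def joint_loss(inputs):
            total = 0.0
            for a, b in self.agreements:
                fa, fb = self.models[a], self.models[b]
                # Penalize squared disagreement between the two
                # submodels' predictions on each shared input.
                total += sum((fa(x) - fb(x)) ** 2 for x in inputs)
            return total
        return joint_loss

# A two-node diagram in the spirit of knowledge distillation: a
# "student" is trained to agree with a fixed "teacher".
diagram = LearningDiagram(
    models={"teacher": lambda x: 2 * x, "student": lambda x: 2 * x + 1},
    agreements=[("teacher", "student")],
)
joint_loss = diagram.compile_loss()
print(joint_loss([0.0, 1.0, 2.0]))  # 3.0: disagreement of 1 at each of 3 points
```

Because the diagram is plain data, the reconfiguration the paper describes amounts to editing `models` or `agreements` and recompiling—no training code changes.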