🤖 AI Summary
Medical imaging models often exhibit poor cross-center generalization due to domain shifts arising from heterogeneous imaging devices, acquisition protocols, and patient populations. Moreover, hospitals deploy diverse model architectures (e.g., MLPs, CNNs, GNNs), complicating collaborative optimization. To address this, we propose the first unified graph learning framework: heterogeneous model parameters are encoded as nodes, and architectural relationships are modeled as edges, forming a shared graph-structured parameter space. We introduce a unified Graph Neural Network (uGNN) that enables cross-architecture parameter sharing and global knowledge transfer, effectively decoupling architecture-specific structural biases from generalizable representations. Extensive experiments on multi-source benchmarks - MorphoMNIST and the MedMNIST benchmarks PneumoniaMNIST and BreastMNIST - demonstrate substantial improvements in out-of-distribution generalization (training on distinct distributions, testing on mixed ones) and robustness under large distribution shifts.
📝 Abstract
Deep learning models often struggle to maintain generalizability in medical imaging, particularly under domain-fracture scenarios where distribution shifts arise from varying imaging techniques, acquisition protocols, patient populations, demographics, and equipment. In practice, each hospital may need to train distinct models - differing in learning task, width, and depth - to match local data. For example, one hospital may use Euclidean architectures such as MLPs and CNNs for tabular or grid-like image data, while another may require non-Euclidean architectures such as graph neural networks (GNNs) for irregular data like brain connectomes. How to train such heterogeneous models coherently across datasets, while enhancing each model's generalizability, remains an open problem. We propose unified learning, a new paradigm that encodes each model into a graph representation, enabling unification in a shared graph learning space. A GNN then guides optimization of these unified models. By decoupling parameters of individual models and controlling them through a unified GNN (uGNN), our method supports parameter sharing and knowledge transfer across varying architectures (MLPs, CNNs, GNNs) and distributions, improving generalizability. Evaluations on MorphoMNIST and two MedMNIST benchmarks - PneumoniaMNIST and BreastMNIST - show that unified learning boosts performance when models are trained on unique distributions and tested on mixed ones, demonstrating strong robustness to unseen data with large distribution shifts. Code and benchmarks: https://github.com/basiralab/uGNN
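To make the unification idea concrete, here is a minimal numpy sketch of the core construction: each layer of each model becomes a node (its parameters flattened into a fixed-length feature vector), consecutive layers are linked by edges, the per-model graphs are joined into one shared graph, and a single message-passing step with a shared weight matrix acts across all architectures. The feature dimension, chain-edge topology, and mean-aggregation GNN layer are illustrative simplifications, not the paper's exact encoding or uGNN architecture.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode_model_as_graph(layer_shapes, feat_dim=16):
    """Encode each layer's parameters as one node feature vector.
    Parameters are flattened, then padded or truncated to feat_dim
    (a simplification of the paper's parameter encoding)."""
    nodes = []
    for shape in layer_shapes:
        w = rng.standard_normal(shape).ravel()
        v = np.zeros(feat_dim)
        v[:min(feat_dim, w.size)] = w[:feat_dim]
        nodes.append(v)
    n = len(layer_shapes)
    # Chain edges between consecutive layers, plus self-loops.
    A = np.eye(n)
    for i in range(n - 1):
        A[i, i + 1] = A[i + 1, i] = 1.0
    return np.stack(nodes), A

def gnn_layer(H, A, W):
    """One message-passing step: mean-aggregate neighbors, project, ReLU."""
    A_norm = A / A.sum(axis=1, keepdims=True)
    return np.maximum(A_norm @ H @ W, 0.0)

# Two heterogeneous models (an MLP and a small CNN) unified in one graph.
mlp_shapes = [(784, 64), (64, 10)]
cnn_shapes = [(8, 1, 3, 3), (16, 8, 3, 3), (16 * 5 * 5, 10)]
H1, A1 = encode_model_as_graph(mlp_shapes)
H2, A2 = encode_model_as_graph(cnn_shapes)

# Block-diagonal adjacency joins both models into one graph; the shared
# GNN weight W is what couples the architectures during optimization.
H = np.vstack([H1, H2])
A = np.block([[A1, np.zeros((2, 3))],
              [np.zeros((3, 2)), A2]])
W = rng.standard_normal((16, 16)) * 0.1
H_out = gnn_layer(H, A, W)
print(H_out.shape)  # (5, 16): one updated feature vector per layer-node
```

In the actual framework the GNN's output parameterizes each local model, so gradients from every dataset flow through the shared weights; this sketch only shows the graph construction and a forward pass.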