AI Summary
This work addresses the inconsistency in binary (yes/no) responses of medical vision-language models when confronted with semantically equivalent but linguistically diverse formulations of the same clinical question. To mitigate this issue, the study introduces a novel approach that integrates mechanistic interpretability analysis with low-rank adaptation (LoRA) fine-tuning. Specifically, sparse autoencoders (Gemma Scope 2) are employed to identify task-critical neurons, and a paraphrase consistency loss is incorporated during training to encourage stable predictions across rephrasings while avoiding mode collapse. Experiments on the MIMIC-CXR and PadChest datasets demonstrate substantial improvements: answer flip rates decrease from 14.6% to 4.4% and from 13.6% to 7.8%, respectively, and the average logit difference between paraphrased queries is reduced by 79.5%, all while maintaining or slightly improving overall accuracy.
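The two consistency metrics referenced above can be made concrete with a short sketch. The function name and the pair representation are illustrative assumptions; the definitions follow the text: an answer "flips" when the yes/no decision differs between a question and its paraphrase, and the margin difference is the gap between their yes-minus-no logit margins.

```python
# Hedged sketch (hypothetical names/shapes) of the two paraphrase-consistency
# metrics: flip rate and mean logit-margin difference.
from typing import List, Tuple

def consistency_metrics(pairs: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Each pair holds the model's (logit_yes - logit_no) margin for the
    original question and for its paraphrase. A flip occurs when the two
    margins have opposite signs (the yes/no answer changes); the margin
    difference is the absolute gap between the two margins."""
    flips = sum(1 for a, b in pairs if (a > 0) != (b > 0))
    mean_margin_diff = sum(abs(a - b) for a, b in pairs) / len(pairs)
    return flips / len(pairs), mean_margin_diff

# Toy example: 4 paraphrase pairs, one of which flips (margins disagree in sign).
pairs = [(2.1, 1.8), (-0.9, -1.2), (0.4, -0.3), (3.0, 2.5)]
flip_rate, margin_diff = consistency_metrics(pairs)
```

On this toy input the flip rate is 0.25 (one sign change out of four pairs); in the paper these quantities are computed over the PSF-Med paraphrase pairs.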
Abstract
Medical Vision-Language Models can give different yes or no answers to rephrasings of the same clinical question. We study this in MedGemma-4B using PSF-Med (Sadanandan and Behzadan, 2025), which provides paraphrase pairs for systematic consistency evaluation on medical VQA. On MIMIC-CXR binary questions (n = 158), the baseline flip rate is 14.6% and mean margin difference is 1.63 logits. We validate that Gemma Scope 2 Sparse Autoencoders (SAEs) transfer to MedGemma activations, achieving R² ≈ 0.997 on both medical and general text (n = 100 prompts each, p < 0.001 for exceeding a 0.95 threshold). We then fine-tune Low-Rank Adaptation (LoRA) adapters with a combined loss that balances paraphrase consistency with answer accuracy. This combined approach prevents the mode collapse that occurs with pure consistency training while reducing flip rate from 14.6% to 4.4% (p = 0.002, two-proportion z-test) and margin difference from 1.63 to 0.33 (79.5% reduction). Accuracy remains stable at 84.2% baseline versus 82.3% after training (−1.9 pp, not significant). On PadChest Balanced (n = 250), flip rate drops from 13.6% to 7.8%, mean margin difference drops from 1.08 to 0.35 (67.9% reduction), and accuracy increases from 66.4% to 69.4%. A layer-range ablation shows that early layers reduce margin differences more than mechanistically selected middle layers.
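The combined objective described in the abstract balances answer accuracy against paraphrase consistency. The exact functional form and the weighting coefficient are not specified in the abstract, so the sketch below is an assumption: cross-entropy on the original question's yes/no logits plus a squared penalty on the gap between the two paraphrases' logit margins, weighted by a hypothetical `lam`.

```python
import math

def combined_loss(logits_orig, logits_para, label, lam=1.0):
    """Hedged sketch of a consistency-plus-accuracy loss. logits_* are
    [logit_no, logit_yes]; label is 0 (no) or 1 (yes). The squared-margin
    consistency term and the lam weighting are illustrative assumptions,
    not the paper's exact specification."""
    def cross_entropy(logits, y):
        # Numerically stable log-sum-exp cross-entropy over the two logits.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        return log_z - logits[y]

    def margin(lg):
        return lg[1] - lg[0]  # yes minus no

    # Penalize disagreement between the paraphrases' yes/no margins; the
    # accuracy term keeps the model from collapsing to one constant answer.
    consistency = (margin(logits_orig) - margin(logits_para)) ** 2
    return cross_entropy(logits_orig, label) + lam * consistency

# Usage: original margin 2.0, paraphrase margin 1.0, gold answer "yes".
loss = combined_loss([0.0, 2.0], [0.0, 1.0], label=1, lam=1.0)
```

Anchoring the cross-entropy term on the gold answer is what prevents the mode collapse the abstract reports for pure consistency training, since a model that always answers "no" is perfectly consistent but penalized on accuracy.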