🤖 AI Summary
CUDA Graphs in PyTorch are hard to deploy and can carry high overhead, sometimes even slowing programs down (negative speedup), because graphs are static and kernel parameters incur redundant host-to-device copies. This paper proposes a compiler-level, fully automatic optimization framework that requires no changes to user code. First, it introduces a novel cost-benefit-driven graph selection mechanism that enables or bypasses graph capture based on runtime characteristics. Second, it eliminates redundant kernel parameter copies, a previously unaddressed bottleneck. Third, it widens the set of programs for which static graphs can be captured and reused automatically, covering more complex ML workflows. The framework tightly integrates PyTorch 2's compilation stack, CUDA Graphs, DAG-structured program analysis, and runtime heuristic decision-making. Evaluated across diverse ML benchmarks, it consistently outperforms PyTorch 2, completely eliminates negative speedup, and delivers 1.3×–2.1× average end-to-end speedup.
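The cost-benefit selection described above can be pictured as a per-region decision rule. The sketch below is a minimal, hypothetical illustration. It assumes each compiled region has already been profiled for eager execution time, graph replay time, and the time spent copying inputs before a replay. `Profile`, `should_use_graph`, `plan`, and the region names are invented for illustration. The actual cost model is not specified here.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Profile:
    """Hypothetical runtime measurements for one compiled region, in microseconds."""
    eager_us: float         # running the region kernel by kernel
    graph_replay_us: float  # replaying the captured CUDA Graph
    param_copy_us: float    # copying fresh inputs/parameters before each replay


def should_use_graph(p: Profile, min_gain_us: float = 0.0) -> bool:
    """Assumed rule: use the graph only if it saves more than min_gain_us per call."""
    graph_total = p.graph_replay_us + p.param_copy_us
    return p.eager_us - graph_total > min_gain_us


def plan(regions: dict[str, Profile]) -> dict[str, str]:
    """Map each region name to "graph" or "eager" using the rule above."""
    return {name: "graph" if should_use_graph(p) else "eager"
            for name, p in regions.items()}


if __name__ == "__main__":
    regions = {
        # Many tiny kernels: launch overhead dominates, so capture pays off.
        "attention_block": Profile(eager_us=120.0, graph_replay_us=40.0, param_copy_us=10.0),
        # Large copies before replay erase the saving: fall back to eager.
        "embedding_update": Profile(eager_us=60.0, graph_replay_us=35.0, param_copy_us=30.0),
    }
    print(plan(regions))  # {'attention_block': 'graph', 'embedding_update': 'eager'}
```

Folding the copy time into the graph's cost shows how a region can end up slower with a graph than without one. Bypassing capture in exactly those cases is one way negative speedup could be avoided.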
📝 Abstract
CUDA Graphs, a recent hardware feature introduced for NVIDIA GPUs, aim to reduce CPU launch overhead by capturing a series of GPU tasks (kernels) as a DAG and launching it as a single unit. However, deploying CUDA Graphs faces several challenges today because a graph's structure is static. Graphs also incur performance overhead from data copies. In fact, we show a counter-intuitive result: deploying CUDA Graphs hurts performance in many cases. We introduce PyGraph, a novel approach to automatically harness the power of CUDA Graphs within PyTorch2. Driven by three key observations, PyGraph embodies three novel optimizations: it enables wider deployment of CUDA Graphs, reduces GPU kernel parameter copy overheads, and selectively deploys CUDA Graphs based on a cost-benefit analysis. PyGraph integrates seamlessly with PyTorch2's compilation toolchain, enabling efficient use of CUDA Graphs without manual code modifications. We evaluate PyGraph across various machine learning benchmarks, demonstrating substantial performance improvements over PyTorch2.
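To make the capture/replay model and its copy overhead concrete, here is a CPU-only toy simulation. It is an illustration under stated assumptions, not the CUDA Graphs or PyTorch API. In PyTorch the real entry points are `torch.cuda.CUDAGraph` and the `torch.cuda.graph` context manager. `ToyGraph`, `add_one`, `double`, and `run_eager` are hypothetical names. In the toy, captured kernels are bound to one static buffer, so each replay first copies new inputs into that buffer. In exchange, a single host launch replaces one launch per kernel.

```python
from typing import Callable, List, Tuple

# A toy "kernel" mutates a buffer in place, like a GPU kernel writing device memory.
Kernel = Callable[[List[float]], None]


def add_one(buf: List[float]) -> None:
    for i in range(len(buf)):
        buf[i] += 1.0


def double(buf: List[float]) -> None:
    for i in range(len(buf)):
        buf[i] *= 2.0


def run_eager(kernels: List[Kernel], buf: List[float]) -> Tuple[List[float], int]:
    """Launch each kernel separately: one host-side launch per kernel."""
    launches = 0
    for k in kernels:
        k(buf)
        launches += 1
    return buf, launches


class ToyGraph:
    """CPU-only toy model of capture/replay; not the CUDA or PyTorch API."""

    def __init__(self, size: int) -> None:
        self.static_buf = [0.0] * size  # fixed buffer baked into the captured graph
        self.kernels: List[Kernel] = []
        self.cpu_launches = 0           # host-side launches issued for replays
        self.copied_elems = 0           # elements copied into static_buf

    def capture(self, kernels: List[Kernel]) -> None:
        self.kernels = list(kernels)    # record the kernel sequence once

    def replay(self, new_input: List[float]) -> List[float]:
        if len(new_input) != len(self.static_buf):
            raise ValueError("captured graph has a static shape")
        self.static_buf[:] = new_input  # the extra copy that eager execution avoids
        self.copied_elems += len(new_input)
        self.cpu_launches += 1          # a single launch replays the whole sequence
        for k in self.kernels:
            k(self.static_buf)
        return list(self.static_buf)


if __name__ == "__main__":
    seq = [add_one, double, add_one]
    g = ToyGraph(size=3)
    g.capture(seq)
    print(g.replay([1.0, 2.0, 3.0]), g.cpu_launches, g.copied_elems)  # [5.0, 7.0, 9.0] 1 3
    print(run_eager(seq, [1.0, 2.0, 3.0]))                            # ([5.0, 7.0, 9.0], 3)
```

The trade-off falls out directly. Replay saves launches in proportion to the number of kernels but pays a copy proportional to the input size. This is why, in this toy model, capture can lose when a region has few kernels and large inputs.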