🤖 AI Summary
To address the prohibitive computational complexity, terabyte-scale memory requirements, and single-GPU compute/memory bottlenecks of fully homomorphic encryption (FHE)-based inference for large language models (e.g., Llama3-8B), this paper introduces Cerium, the first multi-GPU cooperative FHE inference framework. Its core innovations are: (1) millisecond-scale GPU bootstrapping (7.5 ms); (2) an FHE-optimized sparse polynomial intermediate representation (IR) coupled with communication-aware parallel compilation; and (3) a domain-specific language (DSL)-driven generator of high-performance GPU kernels with cross-GPU tensor-parallel runtime support. Experiments show that Cerium achieves the first end-to-end encrypted inference for BERT-Base and Llama3-8B. It outperforms hand-optimized libraries by up to 2.25× on smaller models and matches the FHE throughput of the state-of-the-art ASIC accelerator CraterLake. Notably, encrypted Llama3-8B inference completes in just 134 seconds.
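To give a flavor of why a sparse polynomial representation helps, here is a minimal conceptual sketch (not Cerium's actual IR): storing only the nonzero coefficients of a polynomial in the negacyclic ring Z_q[X]/(X^N + 1) lets arithmetic skip the zero terms that dominate padded encrypted tensors. All names and parameters below are illustrative assumptions.

```python
# Conceptual sketch only: a sparse polynomial is a dict {exponent: coefficient},
# so multiplication touches only nonzero terms instead of all N coefficients.

def sparse_mul(p, q, modulus, ring_degree):
    """Multiply two sparse polynomials in Z_modulus[X]/(X^ring_degree + 1)."""
    out = {}
    for e1, c1 in p.items():
        for e2, c2 in q.items():
            e = e1 + e2
            c = (c1 * c2) % modulus
            if e >= ring_degree:          # negacyclic wraparound: X^N = -1
                e -= ring_degree
                c = (-c) % modulus
            out[e] = (out.get(e, 0) + c) % modulus
    return {e: c for e, c in out.items() if c}  # drop coefficients that cancel

# Example: (3 + 2X^5) * (4X) in Z_97[X]/(X^8 + 1) -> 12X + 8X^6
print(sparse_mul({0: 3, 5: 2}, {1: 4}, 97, 8))
```

A dense implementation would loop over all N^2 coefficient pairs; the sparse form's cost scales with the number of nonzero terms, which is the kind of saving an FHE compiler can exploit when ciphertext-encoded tensors are mostly padding.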
📄 Abstract
Encrypted AI using fully homomorphic encryption (FHE) provides strong privacy guarantees, but its slow performance has limited practical deployment. Recent works have proposed ASICs to accelerate FHE, but these require expensive advanced manufacturing processes that constrain their accessibility. GPUs are a far more accessible platform, yet achieving ASIC-level performance on GPUs has remained elusive. Furthermore, state-of-the-art approaches focus primarily on small models that fit comfortably within a single device. Supporting large models such as LLMs in FHE introduces a dramatic increase in computational complexity that requires optimized GPU kernels, along with terabyte-scale memory footprints that far exceed the capacity of a single GPU. This paper presents Cerium, a multi-GPU framework for FHE inference on large models. Cerium integrates a domain-specific language, an optimizing compiler, and a runtime system to automatically generate high-performance GPU kernels, manage terabyte-scale memory footprints, and parallelize computation across multiple GPUs. It introduces new IR constructs, compiler passes, sparse polynomial representations, memory-efficient data layouts, and communication-aware parallelization techniques that together enable encrypted inference for models ranging from small CNNs to Llama3-8B. We build Cerium on NVIDIA GPUs and demonstrate significant performance gains. For small models, Cerium outperforms expert-written, hand-optimized GPU libraries by up to 2.25×. It achieves performance competitive with state-of-the-art FHE ASICs, matching the prior FHE ASIC CraterLake outright. Cerium is the first GPU system to execute bootstrapping in under 10 milliseconds (7.5 ms), and the first to demonstrate encrypted inference for BERT-Base and Llama3-8B, in 8 seconds and 134 seconds respectively.
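The abstract's "parallelize computation across multiple GPUs" can be illustrated with a minimal tensor-parallelism sketch. This is a hypothetical illustration of the general technique, not Cerium's runtime: devices are simulated with plain Python lists, the weight matrix is split column-wise, and the final concatenation stands in for an all-gather collective over real GPU interconnects.

```python
# Conceptual sketch of column-wise tensor parallelism (assumed, not Cerium's
# actual runtime): each "device" holds a shard of W's columns, computes its
# local matmul, and the results are gathered back into the full output.
import numpy as np

def tensor_parallel_matmul(x, w, num_devices):
    """Compute x @ w by sharding w's columns across simulated devices."""
    shards = np.array_split(w, num_devices, axis=1)   # one weight shard per device
    partials = [x @ shard for shard in shards]        # independent local matmuls
    return np.concatenate(partials, axis=1)           # stands in for all-gather

x = np.arange(6.0).reshape(2, 3)
w = np.arange(12.0).reshape(3, 4)
assert np.allclose(tensor_parallel_matmul(x, w, 2), x @ w)
```

In the encrypted setting the same partitioning applies to ciphertext tensors, but the gather step becomes cross-GPU communication of ciphertexts, which is why communication-aware parallelization matters.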