🤖 AI Summary
Existing spiking neural network (SNN) simulators struggle to simultaneously achieve usability, flexibility, and computational efficiency, hindering rapid prototyping of biologically inspired models and neuromorphic hardware. To address this, we propose the first SNN framework deeply integrated with JAX’s automatic differentiation, just-in-time (JIT) compilation, and functional programming paradigm—unifying PyTorch-level API simplicity with JAX-level execution performance. Our framework enables hardware-aware modeling, customizable spiking neuron dynamics (e.g., leaky integrate-and-fire), and end-to-end differentiable training. Evaluated on standard SNN benchmarks, it achieves 3–10× higher simulation throughput than Brian2, SpykeTorch, and Sinabs, while accelerating training convergence and reducing memory footprint. These advances significantly bridge the gap between algorithmic innovation and hardware deployment, enabling co-optimization across the SNN design stack.
📝 Abstract
Spiking Neural Network (SNN) simulators are essential tools for prototyping biologically inspired models and neuromorphic hardware architectures and for predicting their performance. For such a tool, ease of use and flexibility are critical, but so is simulation speed, especially given the complexity inherent in simulating SNNs. Here, we present SNNAX, a JAX-based framework for simulating and training such models with PyTorch-like intuitiveness and JAX-like execution speed. SNNAX models are easily extended and customized to fit the desired model specifications and target neuromorphic hardware. Additionally, SNNAX offers key features for optimizing the training and deployment of SNNs, such as flexible automatic differentiation and just-in-time compilation. We evaluate SNNAX and compare it to other machine learning (ML) frameworks commonly used for programming SNNs. We provide key performance metrics, best practices, and documented examples for simulating SNNs in SNNAX, and we implement several benchmarks used in the literature.
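To make "flexible automatic differentiation and just-in-time compilation" concrete for SNNs, the sketch below shows how the underlying JAX primitives fit together for a single layer of leaky integrate-and-fire (LIF) neurons. It is written in plain JAX and is **not** SNNAX's API: the names `spike`, `lif_forward`, and `loss_fn` are hypothetical, and the leak factor `beta=0.9`, threshold `1.0`, soft reset, and fast-sigmoid surrogate gradient with slope 10 are illustrative assumptions, not choices taken from the paper. `jax.custom_jvp` replaces the zero-almost-everywhere derivative of the spike threshold with a smooth surrogate, `jax.lax.scan` unrolls the membrane dynamics over time, and `jax.jit`/`jax.grad` compile and differentiate the whole simulation end to end.

```python
import jax
import jax.numpy as jnp

# Heaviside spike nonlinearity with a surrogate gradient (assumed: fast sigmoid).
@jax.custom_jvp
def spike(v):
    return (v > 0.0).astype(v.dtype)

@spike.defjvp
def spike_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    out = spike(v)
    # Surrogate derivative 1 / (1 + k|v|)^2 with an assumed slope k = 10.
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(v)) ** 2
    return out, surrogate * dv

def lif_forward(weights, inputs, beta=0.9, threshold=1.0):
    """Hypothetical LIF layer. inputs: [T, n_in], weights: [n_in, n_out] -> spikes [T, n_out]."""
    def step(v, x_t):
        v = beta * v + x_t @ weights   # leaky integration of weighted input
        s = spike(v - threshold)       # emit a spike where v crosses the threshold
        v = v - s * threshold          # soft reset: subtract the threshold after a spike
        return v, s
    v0 = jnp.zeros(weights.shape[1], dtype=weights.dtype)
    _, spikes = jax.lax.scan(step, v0, inputs)
    return spikes

@jax.jit
def loss_fn(weights, inputs, target_counts):
    # Squared error between per-neuron spike counts and target counts.
    counts = lif_forward(weights, inputs).sum(axis=0)
    return jnp.mean((counts - target_counts) ** 2)

# Compiled gradient of the loss with respect to the weights.
grad_fn = jax.jit(jax.grad(loss_fn))

# Usage: random binary input spike trains, a few plain gradient-descent steps.
key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
weights = 0.5 * jax.random.normal(k_w, (4, 3))
inputs = (jax.random.uniform(k_x, (20, 4)) < 0.3).astype(jnp.float32)
target = jnp.array([2.0, 5.0, 1.0])
for _ in range(50):
    weights = weights - 0.01 * grad_fn(weights, inputs, target)

spikes = lif_forward(weights, inputs)
grads = grad_fn(weights, inputs, target)
```

Because the simulation is an ordinary pure function, the same `loss_fn` can be vectorized with `jax.vmap` or compiled for an accelerator without changes; a framework such as SNNAX presumably wraps this kind of pattern behind reusable layer abstractions.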