Physics ∩ ML

a virtual hub at the interface of theoretical physics and deep learning.

08 Sep 2021

JAX MD: A Framework for Differentiable Atomistic Physics

Sam Schoenholz, Google Brain, 12:00 ET

I will talk about JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds of thousands of particles on a single GPU. My talk will include an introduction to the JAX software package. If you are interested in trying out JAX MD, it is available at
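To illustrate what "differentiating an entire trajectory" means, here is a minimal sketch written in plain JAX rather than with JAX MD's own API. It assumes a hypothetical toy system (unit-mass particles tethered to the origin by springs of stiffness `k`), integrates it with velocity Verlet using forces obtained from `jax.grad` of the energy, and then takes the gradient of a loss on the final positions with respect to `k`. The names `harmonic_energy`, `simulate`, and `loss` are illustrative only.

```python
import jax
import jax.numpy as jnp


def harmonic_energy(x, k):
    # Hypothetical toy potential: unit-mass particles on springs of stiffness k.
    return 0.5 * k * jnp.sum(x ** 2)


def simulate(k, x0, v0, dt=0.01, steps=100):
    # Forces come from automatic differentiation of the energy: F = -dE/dx.
    grad_energy = jax.grad(harmonic_energy)

    def step(state, _):
        x, v = state
        # Velocity Verlet update (mass = 1, so acceleration = force).
        a = -grad_energy(x, k)
        x_new = x + dt * v + 0.5 * dt ** 2 * a
        a_new = -grad_energy(x_new, k)
        v_new = v + 0.5 * dt * (a + a_new)
        return (x_new, v_new), None

    # lax.scan runs the whole trajectory as one differentiable function.
    (x_final, _), _ = jax.lax.scan(step, (x0, v0), None, length=steps)
    return x_final


def loss(k, x0, v0):
    # Meta-objective on the end state of the trajectory.
    return jnp.sum(simulate(k, x0, v0) ** 2)


x0 = jnp.array([1.0, 0.0])
v0 = jnp.zeros(2)

# Gradient of the meta-objective with respect to a physical parameter,
# obtained by backpropagating through all 100 integration steps.
dloss_dk = jax.grad(loss)(2.0, x0, v0)
```

Because the integrator is an ordinary JAX function, the same `jax.grad` call that produces forces also differentiates through the full trajectory; no hand-written adjoint is needed for this toy case.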