Learning diffusion at lightspeed

Paper & Code


Diffusion regulates numerous natural processes and drives the dynamics of many successful generative models. Current models for learning diffusion terms from observational data often require solving complex bilevel optimization problems and primarily focus on modeling only the drift component of the system.

We propose a new, simple model, JKOnet*, which bypasses the complexity of existing architectures while offering significantly enhanced representational capabilities: JKOnet* recovers the potential, interaction, and internal energy components of the underlying diffusion process. JKOnet* minimizes a simple quadratic loss and drastically outperforms other baselines in sample efficiency, computational complexity, and accuracy. Additionally, JKOnet* admits a closed-form optimal solution for linearly parametrized functionals, and, when applied to predict the evolution of cellular processes from real-world data, it achieves state-of-the-art accuracy at a fraction of the computational cost of all existing methods.

Key advantages of JKOnet*

  • Outperforms existing baselines in sample efficiency, computational complexity, and accuracy.

  • Learns the different components of the diffusion process, including potential, interaction, and internal energy.

  • Provides a closed-form optimal solution for linearly parametrized functionals.

  • Achieves state-of-the-art accuracy in predicting cellular process evolution at a fraction of the computational cost of existing methods.

JKOnet* vs JKOnet

|          | Potential energy | Interaction energy | Internal energy | Speed   |
|----------|------------------|--------------------|-----------------|---------|
| JKOnet   | ✅               | ❌                 | ❌              | slow    |
| JKOnet*  | ✅               | ✅                 | ✅              | fast 🔥 |

Our methodology is based on the interpretation of diffusion processes as energy-minimizing trajectories in the probability space via the so-called JKO scheme, which we study via its first-order optimality conditions.
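For reference, the (standard) JKO scheme constructs the diffusion as a sequence of proximal steps in the Wasserstein space: each new distribution trades off decreasing an energy functional against staying close to the previous distribution,

```latex
\mu_{t+1} \in \operatorname*{arg\,min}_{\mu \in \mathcal{P}_2(\mathbb{R}^d)}
  \; J(\mu) + \frac{1}{2\tau} W_2^2(\mu, \mu_t),
```

where $J$ collects the potential, interaction, and internal energy terms, $W_2$ is the Wasserstein distance, and $\tau$ is the time step. The first-order optimality conditions of this problem are what the method exploits.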

Check out the paper for an intuition as well as an in-depth explanation and thorough comparisons with existing methods.

Citation 🙏

If you use this code in your research, please cite our paper (NeurIPS 2024, Oral Presentation):

@article{terpin2024learning,
   title={Learning diffusion at lightspeed},
   author={Terpin, Antonio and Lanzetti, Nicolas and Gadea, Mart{\'\i}n and D\"{o}rfler, Florian},
   journal={Advances in Neural Information Processing Systems},
   volume={37},
   pages={6797--6832},
   year={2024}
}

Contact and contributing

If you have any questions, want to report an error, or would like to contribute to the project, feel free to reach out to Antonio Terpin via email (aterpin@ethz.ch) or open an issue/PR directly on the GitHub repository.