Learning single-cell diffusion dynamics 🧬

Understanding the time evolution of cellular processes subject to external stimuli is a fundamental open question in biology. Motivated by the intuition that cells differentiate by minimizing some energy functional, we deploy JKOnet* to analyze the embryoid body single-cell RNA sequencing (scRNA-seq) data [6], which describes the differentiation of human embryonic stem cells over a period of 27 days.

Generating the data

We use the dataset of Moon et al. [6], which is publicly available as scRNAseq.zip on Mendeley Data. It tracks the differentiation of human embryonic stem cells over a 27-day period, with cell snapshots collected at the following time intervals:

| Timestep | \(t_{0}\) | \(t_{1}\) | \(t_{2}\) | \(t_{3}\) | \(t_{4}\) |
|----------|-----------|-----------|-----------|-----------|-----------|
| Days     | 0-3       | 6-9       | 12-15     | 18-21     | 24-27     |

We follow the data preprocessing of [8] and [7]. In particular, we use the same processed embryoid body data released with these works, which contain the first 100 principal components of the data obtained by principal component analysis (PCA). The data is located in data/TrajectoryNet/eb_velocity_v5_npz.

To start, we load the data, scale it as in dataset.py of [8], and save it in the format required by our preprocessing script, keeping only the first 5 principal components. The preprocessing script creates two files in /data/RNA_PCA_5; RNA_PCA_5 is the dataset_name to pass to the training script.

python preprocess_rna_seq.py --n-components 5
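The scaling and truncation step can be illustrated with the minimal sketch below. It is not the code of preprocess_rna_seq.py: it assumes the scaling is a per-component z-score (the exact transformation in dataset.py of [8] may differ), the helpers standardize_and_truncate and split_by_timestep are hypothetical, and synthetic arrays stand in for the 100-component PCA embedding and its time labels.

```python
import numpy as np


def standardize_and_truncate(pcs: np.ndarray, n_components: int = 5) -> np.ndarray:
    """Z-score every PCA coordinate, then keep the first `n_components` columns.

    Hypothetical helper: assumes the scaling is a per-column z-score.
    """
    mean = pcs.mean(axis=0)
    std = pcs.std(axis=0)
    std[std == 0] = 1.0  # guard against constant columns
    return ((pcs - mean) / std)[:, :n_components]


def split_by_timestep(x: np.ndarray, labels: np.ndarray) -> dict:
    """Group cells into one snapshot (array of points) per time label."""
    return {int(t): x[labels == t] for t in np.unique(labels)}


# Synthetic stand-in for the 100-component PCA embedding and 5 time labels.
rng = np.random.default_rng(0)
pcs = rng.normal(loc=2.0, scale=3.0, size=(1000, 100))
labels = rng.integers(0, 5, size=1000)

x = standardize_and_truncate(pcs, n_components=5)
snapshots = split_by_timestep(x, labels)
print(x.shape, {t: len(s) for t, s in snapshots.items()})
```

Standardizing all 100 columns before truncating gives the same result as standardizing only the first 5, since the z-score acts column by column.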

Next, we generate the data for training and evaluation. We perform a 60-40 train-test split and, on the training data, compute the couplings used to train JKOnet*:

python data_generator.py --load-from-file RNA_PCA_5 --test-ratio 0.4 --split-population
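The two operations performed in this step can be illustrated with the sketch below; it is not the implementation of data_generator.py. The per-snapshot random split follows the 60-40 ratio above, while the coupling between consecutive training snapshots is computed with entropic optimal transport (Sinkhorn iterations) under a squared Euclidean cost. This is a swapped-in approximation: the repository may compute the couplings with exact optimal transport instead. The helpers split_snapshot and sinkhorn_coupling, and the values of eps and n_iter, are hypothetical.

```python
import numpy as np


def split_snapshot(points, test_ratio=0.4, rng=None):
    """Randomly split one snapshot into (train, test) cells."""
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.permutation(len(points))
    n_test = int(round(test_ratio * len(points)))
    return points[idx[n_test:]], points[idx[:n_test]]


def sinkhorn_coupling(x, y, eps=0.1, n_iter=1000):
    """Entropic OT coupling between uniform empirical measures on x and y.

    Entry (i, j) of the returned matrix is the mass moved from x[i] to y[j].
    The squared Euclidean cost is rescaled to [0, 1] for numerical stability.
    """
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / (cost.max() * eps))
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (kernel @ v)    # match the row marginal a
        v = b / (kernel.T @ u)  # match the column marginal b
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
# Two synthetic consecutive snapshots in 5 dimensions.
snap_t = rng.normal(size=(200, 5))
snap_next = rng.normal(loc=0.5, size=(150, 5))

train_t, test_t = split_snapshot(snap_t, 0.4, rng)
train_next, test_next = split_snapshot(snap_next, 0.4, rng)
coupling = sinkhorn_coupling(train_t, train_next)
print(train_t.shape, test_t.shape, coupling.shape)
```

Only the training cells are coupled; the held-out 40% of each snapshot is kept aside for evaluation.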

Training

To train and evaluate the model, we run the train.py script:

python train.py --dataset RNA_PCA_5 --solver jkonet-star-time-potential --epochs 100

We also provide the following scripts to run all the experiments:

bash -x scripts/exp_rna_jkonet_star.sh
bash -x scripts/exp_rna_jkonet.sh
bash -x scripts/exp_rna_jkonet_vanilla.sh

Note

Qualitative analysis of the dataset suggests that the energy governing cell evolution might be well described by a time-varying potential, in line with previous work [7, 8]. For this reason, we use the jkonet-star-time-potential solver, which incorporates time as a parameter in the model. See the paper for more details.
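To make the role of time concrete, the sketch below fits a time-varying potential by penalizing the first-order optimality condition of a JKO step, \(\nabla_x V(x_{t+1}, t) + (x_{t+1} - x_t)/\tau = 0\), evaluated on coupled pairs \((x_t, x_{t+1})\). It is a simplified illustration, not the jkonet-star-time-potential solver. It assumes a hypothetical potential family \(V(x, t) = (c_0 + t c_1)^\top x + \tfrac{1}{2}(a_0 + t a_1)\|x\|^2\), which is linear in its parameters, so the minimization reduces to ordinary least squares; the step size \(\tau\), the convention of evaluating \(V\) at the time of the later snapshot, and the helper names are also assumptions, and the solver's actual parametrization may differ.

```python
import numpy as np


def features(y, t):
    """Design blocks such that grad V(y, t) = features(y, t) @ theta.

    Hypothetical parametrization: theta = [c0 (d), c1 (d), a0, a1] and
    grad V(y, t) = c0 + t * c1 + (a0 + t * a1) * y.
    """
    n, d = y.shape
    eye = np.broadcast_to(np.eye(d), (n, d, d))
    blocks = [eye, t[:, None, None] * eye, y[:, :, None], (t[:, None] * y)[:, :, None]]
    return np.concatenate(blocks, axis=2)  # shape (n, d, 2d + 2)


def fit_time_potential(x_prev, x_next, t_next, tau):
    """Least-squares fit of theta minimizing
    sum_i || grad V(x_next_i, t_i) + (x_next_i - x_prev_i) / tau ||^2.
    """
    phi = features(x_next, t_next)
    target = -(x_next - x_prev) / tau
    n, d, p = phi.shape
    theta, *_ = np.linalg.lstsq(phi.reshape(n * d, p), target.reshape(n * d), rcond=None)
    return theta


# Synthetic pairs that satisfy the optimality condition exactly for a known theta.
rng = np.random.default_rng(0)
n, d, tau = 400, 5, 0.25
theta_true = rng.normal(size=2 * d + 2)
x_next = rng.normal(size=(n, d))
t_next = rng.integers(1, 5, size=n).astype(float)
x_prev = x_next + tau * (features(x_next, t_next) @ theta_true)

theta_hat = fit_time_potential(x_prev, x_next, t_next, tau)
print(np.allclose(theta_hat, theta_true))  # True
```

Setting \(a_1 = 0\) and \(c_1 = 0\) recovers a time-independent potential, which is the distinction the time-potential solver is meant to capture.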

Results

To quantitatively evaluate our results, we train our models on \(60\%\) of the data at each timestep, using only the first \(5\) principal components, and compute the one-step-ahead Earth Mover's Distance (Wasserstein-1 distance) on the test data:

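\[W_{1}(\mu_t, \hat{\mu}_t) = \min_{\gamma \in \Pi(\mu_t, \hat{\mu}_t)} \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x - y\| \, \mathrm{d}\gamma(x, y),\]

where \(\mu_t\) is the empirical distribution of the test cells at timestep \(t\), \(\hat{\mu}_t\) is the one-step-ahead prediction for timestep \(t\), and \(\Pi(\mu_t, \hat{\mu}_t)\) is the set of couplings with marginals \(\mu_t\) and \(\hat{\mu}_t\).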
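For finite samples, the minimization defining \(W_1\) is a linear program. The sketch below assumes, for simplicity, that the test snapshot and the predicted snapshot contain the same number of points with uniform weights; in that case an optimal coupling can be taken to be a permutation, so \(W_1\) reduces to an assignment problem solved with SciPy's linear_sum_assignment. The function name wasserstein_1 is hypothetical and this is not the evaluation code of the repository; for unequal sample sizes, the general problem can be solved, for instance, with ot.emd2 from the POT library.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def wasserstein_1(x, y):
    """W1 between two uniform empirical measures with the same number of points.

    With equal sizes and uniform weights an optimal coupling is a permutation,
    so W1 is the mean cost of an optimal one-to-one matching.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes equally sized samples")
    cost = cdist(x, y)  # pairwise Euclidean distances ||x_i - y_j||
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


rng = np.random.default_rng(0)
x = rng.normal(size=(300, 5))
shift = np.array([3.0, 4.0, 0.0, 0.0, 0.0])
print(wasserstein_1(x, x))          # 0.0
print(wasserstein_1(x, x + shift))  # 5.0 up to floating-point error
```

The translated sample is a useful sanity check: for any translation by a vector \(c\), \(W_1\) equals \(\|c\|\), here \(\|(3, 4, 0, 0, 0)\| = 5\).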

We compare our numerical results with recent methods from the literature evaluated on the first \(5\) principal components of the embryoid body scRNA-seq dataset. The baseline values are taken from [1] (Table 5) and [7] (Table 4).

The following table gathers all the results (EMD: lower is better; "--": standard deviation not reported).

| Algorithm         | EMD (mean) | EMD (std) |
|-------------------|------------|-----------|
| TrajectoryNet [8] | 0.848      | --        |
| Reg. CNF [3]      | 0.825      | --        |
| DSB [2]           | 0.862      | 0.023     |
| I-CFM [7]         | 0.872      | 0.087     |
| SB-CFM [7]        | 1.221      | 0.380     |
| OT-CFM [7]        | 0.790      | 0.068     |
| NLSB [5]          | 0.74       | --        |
| MIOFLOW [4]       | 0.79       | --        |
| DMSB [1]          | 0.67       | --        |
| JKOnet*           | 0.623      | 0.04      |

Note

The literature is fragmented in how the various methods for learning diffusion dynamics from scRNA-seq data are compared. For instance, the EMD values in [2, 3, 7, 8] are computed by leaving out one time point for validation, while [1, 4, 5] compare generated samples to ground-truth data. For this reason, we limit ourselves to noting that JKOnet* appears to perform on par with the best methods in the literature, while being significantly faster to train.

Below, we display the time evolution of the level curves of the potential energy minimized by the cells, projected onto the first two principal components, along with the cell trajectories (green: data; blue: interpolated predictions).

[Figure: RNA]

The top row shows the first two principal components of the scRNA-seq data: ground truth (green; days 1-3, 6-9, 12-15, 18-21, 24-27) and interpolated predictions (blue; days 4-5, 10-11, 16-17, 22-23). The bottom row displays the estimated potential level curves over time. The bottom-left plot superimposes the same three level curves for days 1-3 (solid), 12-15 (dashed), and 24-27 (dashed with larger spaces) to highlight the time dependency.

[1] Tianrong Chen, Guan-Horng Liu, Molei Tao, and Evangelos Theodorou. Deep momentum multi-marginal Schrödinger bridge. Advances in Neural Information Processing Systems, 2024.

[2] Valentin De Bortoli, James Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. Advances in Neural Information Processing Systems, 34:17695–17709, 2021.

[3] Chris Finlay, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam Oberman. How to train your neural ODE: the world of Jacobian and kinetic regularization. In International Conference on Machine Learning, 3154–3164. PMLR, 2020.

[4] Guillaume Huguet, Daniel Sumner Magruder, Alexander Tong, Oluwadamilola Fasina, Manik Kuchroo, Guy Wolf, and Smita Krishnaswamy. Manifold interpolating optimal-transport flows for trajectory inference. Advances in Neural Information Processing Systems, 35:29705–29718, 2022.

[5] Takeshi Koshizuka and Issei Sato. Neural Lagrangian Schrödinger bridge: diffusion modeling for population dynamics. arXiv preprint arXiv:2204.04853, 2022.

[6] Kevin R. Moon, David Van Dijk, Zheng Wang, Scott Gigante, Daniel B. Burkhardt, William S. Chen, Kristina Yim, Antonia van den Elzen, Matthew J. Hirn, Ronald R. Coifman, and others. Visualizing structure and transitions in high-dimensional biological data. Nature Biotechnology, 37(12):1482–1492, 2019.

[7] Alexander Tong, Kilian Fatras, Nikolay Malkin, Guillaume Huguet, Yanlei Zhang, Jarrid Rector-Brooks, Guy Wolf, and Yoshua Bengio. Improving and generalizing flow-based generative models with minibatch optimal transport. Transactions on Machine Learning Research, 2024.

[8] Alexander Tong, Jessie Huang, Guy Wolf, David Van Dijk, and Smita Krishnaswamy. TrajectoryNet: a dynamic optimal transport network for modeling cellular dynamics. In International Conference on Machine Learning, 9526–9536. PMLR, 2020.