Learning single-cell diffusion dynamics 🧬
Understanding the time evolution of cellular processes subject to external stimuli is a fundamental open question in biology. Motivated by the intuition that cells differentiate by minimizing some energy functional, we deploy JKOnet* to analyze the embryoid body single-cell RNA sequencing (scRNA-seq) data [6], which describes the differentiation of human embryonic stem cells over a period of 27 days.
Generating the data
The dataset is the one from Moon et al. [6]. It is publicly available as scRNAseq.zip on Mendeley Datasets.
This dataset tracks the differentiation of human embryonic stem cells over a 27-day period, with cell snapshots collected at the following time intervals:
| Timesteps | \(t_{0}\) | \(t_{1}\) | \(t_{2}\) | \(t_{3}\) | \(t_{4}\) |
|---|---|---|---|---|---|
| Days | 0 - 3 | 6 - 9 | 12 - 15 | 18 - 21 | 24 - 27 |
We follow the data pre-processing in [8] and [7]; in particular, we use the same processed artifacts of the embryoid data provided in their work, which contain the first 100 components of the principal component analysis (PCA) of the data.
The data is located in data/TrajectoryNet/eb_velocity_v5_npz.
To start, we load the data, scale it as in the dataset.py of [8], and save it in the format needed for our preprocessing script. In particular, we consider the first 5 principal components. The preprocessing script creates two files in /data/RNA_PCA_5. RNA_PCA_5 is the dataset_name to use in the training script.
python preprocess_rna_seq.py --n-components 5
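The select-and-scale step described above can be pictured roughly as follows. This is a minimal sketch, not the contents of preprocess_rna_seq.py: the function name select_and_scale is hypothetical, standardization to zero mean and unit variance is an assumed stand-in for the scaling in the dataset.py of [8], and a synthetic array replaces the real PCA artifact.

```python
import numpy as np


def select_and_scale(pcs, n_components=5):
    """Keep the leading principal components and standardize each one.

    Hypothetical helper: standardization is an assumed stand-in for the
    scaling used in the dataset.py of [8].
    """
    x = np.asarray(pcs, dtype=float)[:, :n_components]
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    std[std == 0.0] = 1.0  # avoid division by zero on constant columns
    return (x - mean) / std


# Synthetic stand-in for a (cells x 100) matrix of PCA coordinates.
rng = np.random.default_rng(0)
fake_pcs = rng.normal(loc=3.0, scale=2.0, size=(200, 100))
scaled = select_and_scale(fake_pcs, n_components=5)
print(scaled.shape)  # (200, 5)
```

Truncating before scaling keeps the statistics restricted to the components that are actually used downstream; whether the original pipeline scales before or after truncation is not stated here.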
Next, we generate the data for training and evaluation. We perform a 60-40 train-test split, and for the training data we compute the couplings for training JKOnet*:
python data_generator.py --load-from-file RNA_PCA_5 --test-ratio 0.4 --split-population
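To make the 60-40 split concrete, the sketch below splits the cells of each snapshot separately, so that every timestep keeps the same train-test proportion. It is an illustration under assumptions: split_per_timestep is a hypothetical helper, the data is synthetic, and the actual logic behind --split-population in data_generator.py may differ.

```python
import numpy as np


def split_per_timestep(timesteps, test_ratio=0.4, seed=0):
    """Return train and test indices, splitting every snapshot separately.

    Hypothetical helper illustrating a per-timestep split; it is not the
    implementation behind data_generator.py.
    """
    rng = np.random.default_rng(seed)
    timesteps = np.asarray(timesteps)
    train_idx, test_idx = [], []
    for t in np.unique(timesteps):
        # Shuffle the indices of the cells observed at timestep t.
        idx = rng.permutation(np.flatnonzero(timesteps == t))
        n_test = int(round(test_ratio * len(idx)))
        test_idx.append(idx[:n_test])
        train_idx.append(idx[n_test:])
    return np.concatenate(train_idx), np.concatenate(test_idx)


# Synthetic example: 5 snapshots with 100 cells each, 5 features per cell.
x = np.random.default_rng(1).normal(size=(500, 5))
timesteps = np.repeat(np.arange(5), 100)
train_idx, test_idx = split_per_timestep(timesteps, test_ratio=0.4)
x_train, x_test = x[train_idx], x[test_idx]
print(x_train.shape, x_test.shape)  # (300, 5) (200, 5)
```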
Training
To train and evaluate the model, we run the train.py script.
python train.py --dataset RNA_PCA_5 --solver jkonet-star-time-potential --epochs 100
We also provide the following scripts to run all the experiments:
bash -x scripts/exp_rna_jkonet_star.sh
bash -x scripts/exp_rna_jkonet.sh
bash -x scripts/exp_rna_jkonet_vanilla.sh
Note
Qualitative analysis of the dataset suggests that the energy governing cell evolution might be well-described by a time-varying potential, in line with previous work [7, 8]. For this reason, we use the jkonet-star-time-potential solver, which incorporates time as a parameter in the model. Check out the paper for more details.
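As a rough sketch of what a time-varying potential means in this setting, a JKO step with a potential energy that depends on time, together with its first-order optimality condition, can be written as below. The notation is introduced here for illustration only: \(V\) is the potential, \(\tau\) the step size, \(\mu_k\) the population at timestep \(t_k\), and \(\gamma_k\) an optimal coupling between \(\mu_k\) and \(\mu_{k+1}\). Evaluating the potential at \(t_{k+1}\) is an assumption of this sketch; the paper has the precise formulation.

```latex
% JKO step with a time-dependent potential energy
\[
\mu_{k+1} \in \operatorname*{arg\,min}_{\mu}
\int V(x, t_{k+1}) \,\mathrm{d}\mu(x) + \frac{1}{2\tau} W_2^2(\mu, \mu_k)
\]
% First-order optimality condition along the optimal coupling
\[
\nabla_x V(x_{k+1}, t_{k+1}) + \frac{x_{k+1} - x_k}{\tau} = 0
\quad \text{for } \gamma_k\text{-almost every } (x_k, x_{k+1})
\]
```

With a time-independent potential, a single \(V(x)\) would have to explain every transition between snapshots; letting \(V\) depend on \(t\) allows the energy landscape itself to change over the 27 days.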
Results
To evaluate the quality of our results quantitatively, we train our models on \(60\%\) of the data at each timestep, using only the first \(5\) principal components, and we compute the one-step-ahead Earth Mover’s Distance (Wasserstein-1 error) on the test data.
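For intuition on the metric, the sketch below computes the Wasserstein-1 distance between two point clouds of equal size with uniform weights. In that special case the optimal transport plan is a one-to-one matching, so an exact assignment solver suffices. This is a swapped-in technique for illustration, not necessarily how the repository computes the score: emd_equal_size is a hypothetical helper and the data is synthetic.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def emd_equal_size(pred, target):
    """Wasserstein-1 distance between two equally sized, uniformly weighted point clouds."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    if pred.shape != target.shape:
        raise ValueError("this sketch assumes equally sized point clouds")
    cost = cdist(pred, target)  # pairwise Euclidean distances
    rows, cols = linear_sum_assignment(cost)  # optimal one-to-one matching
    return cost[rows, cols].mean()


# Synthetic check: translating a cloud by a unit vector moves every point by 1.
rng = np.random.default_rng(0)
target = rng.normal(size=(50, 5))
shifted = target + np.array([1.0, 0.0, 0.0, 0.0, 0.0])
print(emd_equal_size(target, target))
print(emd_equal_size(shifted, target))
```

The first value should be zero and the second should be close to one, since the Wasserstein-1 distance between a distribution and its translate equals the length of the translation. For a one-step-ahead error, pred would hold the model's prediction for the next snapshot and target the held-out cells at that snapshot; clouds of different sizes need a general optimal transport solver instead.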
We compare our numerical results with recent work in the literature on the first \(5\) principal components of the embryoid body scRNA-seq dataset. The numerical values for the other methods are taken from [1] (Table 5) and [7] (Table 4).
The following table gathers all the results.
| Algorithm | EMD Score |
|---|---|
| TrajectoryNet [8] | \(0.848 \pm --\) |
| Reg. CNF [3] | \(0.825 \pm --\) |
| DSB [2] | \(0.862 \pm 0.023\) |
| I-CFM [7] | \(0.872 \pm 0.087\) |
| SB-CFM [7] | \(1.221 \pm 0.380\) |
| OT-CFM [7] | \(0.790 \pm 0.068\) |
| NLSB [5] | \(0.74 \pm --\) |
| MIOFLOW [4] | \(0.79 \pm --\) |
| DMSB [1] | \(0.67 \pm --\) |
| JKOnet* | \(0.623 \pm 0.04\) |
Note
The literature is fragmented in how it compares the various methods for learning diffusion terms in the scRNA data. For instance, the EMD numbers in [2, 3, 7, 8] are computed by leaving out one time point for validation, while [1, 4, 5] compare generated samples to ground truth data. For this reason, we limit ourselves to saying that JKOnet* seems to perform as well as the best methods in the literature, while being significantly faster to train.
Below, we display the time evolution of the level curves of the potential energy minimized by the cells, projected onto the first two principal components, along with the cells' trajectories (data in green, interpolated predictions in blue).
The top row shows the two principal components of the scRNA-seq data, ground truth (green, days 1-3, 6-9, 12-15, 18-21, 24-27) and interpolated (blue, days 4-5, 10-11, 16-17, 22-23). The bottom row displays the estimated potential level curves over time. The bottom left plot superimposes the same three level curves for days 1-3 (solid), 12-15 (dashed), and 24-27 (dashed with larger spaces) to highlight the time-dependency.
[1] Tianrong Chen, Guan-Horng Liu, Molei Tao, and Evangelos Theodorou. Deep momentum multi-marginal Schrödinger bridge. Advances in Neural Information Processing Systems, 2024.
[2] Valentin De Bortoli, James Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. Advances in Neural Information Processing Systems, 34:17695–17709, 2021.
[3] Chris Finlay, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam Oberman. How to train your neural ODE: the world of Jacobian and kinetic regularization. In International Conference on Machine Learning, 3154–3164. PMLR, 2020.
[4] Guillaume Huguet, Daniel Sumner Magruder, Alexander Tong, Oluwadamilola Fasina, Manik Kuchroo, Guy Wolf, and Smita Krishnaswamy. Manifold interpolating optimal-transport flows for trajectory inference. Advances in Neural Information Processing Systems, 35:29705–29718, 2022.
[5] Takeshi Koshizuka and Issei Sato. Neural Lagrangian Schrödinger bridge: diffusion modeling for population dynamics. arXiv preprint arXiv:2204.04853, 2022.
[6] Kevin R Moon, David Van Dijk, Zheng Wang, Scott Gigante, Daniel B Burkhardt, William S Chen, Kristina Yim, Antonia van den Elzen, Matthew J Hirn, Ronald R Coifman, and others. Visualizing structure and transitions in high-dimensional biological data. Nature Biotechnology, 37(12):1482–1492, 2019.