Benchmarks 🔥
On this page we report the benchmarks for the JKOnet* model on synthetic data. For the results on single-cell data, please refer to the Learning single-cell diffusion dynamics 🧬 page. See also the paper for more details.
Our models
We use the following terminology for our methods. JKOnet* is the most general non-linear parametrization and JKOnet*V introduces the inductive bias \(\theta_2 = \theta_3 = 0\). Similarly, we refer to the linear parametrizations as JKOnet*l,V and JKOnet*l.
Metrics
To evaluate the prediction capabilities, we use the one-step-ahead earth-mover distance (EMD), defined as

\[
\mathrm{EMD}(\mu_t, \hat{\mu}_t) = \min_{\gamma \in \Gamma(\mu_t, \hat{\mu}_t)} \int \|x - y\| \, \mathrm{d}\gamma(x, y),
\]

where \(\mu_t\) and \(\hat{\mu}_t\) are the observed and predicted populations, respectively, and \(\Gamma(\mu_t, \hat{\mu}_t)\) is the set of couplings between them. In particular, we consider the average and standard deviation over a trajectory.
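For intuition, when the two populations are equally sized, uniformly weighted point clouds, the EMD reduces to a linear assignment problem (the optimal coupling is a permutation). A minimal sketch using scipy, not the repository's own evaluation code:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def emd(x, y):
    """Earth-mover distance between two uniform point clouds of equal size.

    With uniform weights and equal sizes, the optimal transport plan is a
    permutation matrix, so the EMD is the mean cost of an optimal assignment.
    """
    # Pairwise Euclidean costs, shape (n, n)
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()

# A cloud translated by one unit has EMD equal to the shift length
x = np.random.default_rng(0).normal(size=(100, 2))
shifted = x + np.array([1.0, 0.0])
```

For larger clouds or non-uniform weights, a dedicated optimal transport solver (e.g. the POT library) is the more scalable choice.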
Note
The scripts to perform the experiments rely on GNU parallel for parallelization and have only been tested on Ubuntu and macOS; please refer to the Installation guide page. We did not look into running them on Windows; if you get them working there, we would be happy to include the instructions here. You can also reproduce individual results manually using Docker (see the Installation guide page).
Note
The scripts log the results to wandb. Make sure you have a working installation: please check the instructions, or remove the --wandb flag from the scripts you are using.
Experiment 4.1: Training at lightspeed
Experimental Setting
We compare (i) the EMD error, (ii) the convergence ratio, and (iii) the time per epoch required by the different methods on a synthetic dataset consisting of particles subject to a non-linear drift, \(x_{t+1} = x_t - \tau \nabla V(x_t)\), where \(\tau = 0.01\), \(T = 5\), and the potential functions \(V(x)\) are those in utils.functions.
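For reference, the forward dynamics above can be simulated in a few lines. A minimal numpy sketch, using a quadratic potential \(V(x) = \|x\|^2/2\) as an illustrative stand-in for the potentials in utils.functions (the helper name is ours, not the repository's):

```python
import numpy as np

def rollout(x0, grad_V, tau=0.01, T=5):
    """Simulate x_{t+1} = x_t - tau * grad V(x_t) for T steps.

    Returns the list of populations [mu_0, ..., mu_T].
    """
    populations = [x0]
    x = x0
    for _ in range(T):
        x = x - tau * grad_V(x)
        populations.append(x)
    return populations

# Quadratic potential: V(x) = ||x||^2 / 2, so grad V(x) = x
rng = np.random.default_rng(0)
mu_0 = rng.normal(size=(1000, 2))
trajectory = rollout(mu_0, lambda x: x)
```

With this potential, each step contracts the population by a factor \(1 - \tau\) toward the origin, which makes the rollout easy to sanity-check.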
Results
The figure below, which is composed of three plots, collects all the numerical results of the experiment. The scatter plot displays points \((x_i, y_i)\), where \(x_i\) indexes the potentials in utils.functions and \(y_i\) represents the errors (EMD, normalized such that the maximum error among all models and all potentials is 1) obtained with the different models. Each method that diverged during training is marked with NaN. The plot in the bottom-left shows the EMD error trajectory during training, normalized such that 0 and 1 represent the minimum and maximum EMD, respectively, and averaged over all experiments. The shaded area indicates the standard deviation. Additionally, the box plot analyzes the time per epoch required by each method, with statistics compiled across all epochs and all potential energies.
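The normalizations described above are simple rescalings; as a quick sketch (hypothetical helper names, not the repository's plotting code):

```python
import numpy as np

def normalize_to_max(values):
    """Divide by the maximum so the largest value maps to 1 (scatter plot)."""
    values = np.asarray(values, dtype=float)
    return values / values.max()

def minmax_normalize(values):
    """Rescale so the minimum maps to 0 and the maximum to 1 (training curves)."""
    values = np.asarray(values, dtype=float)
    lo, hi = values.min(), values.max()
    return (values - lo) / (hi - lo)

print(minmax_normalize([2.0, 4.0, 6.0]))  # → [0.  0.5 1. ]
```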
All our methods perform uniformly better than the baseline, regardless of their generality. The speed improvement of the JKOnet* model family suggests that a theoretically guided loss may provide strong computational benefits on par with sophisticated model architectures. Our linearly parametrized models, JKOnet*l and JKOnet*l,V, require a computational time per epoch comparable to the JKOnet family, but they only need one epoch to solve the problem optimally. Our non-linear models, JKOnet* and JKOnet*V, instead both require significantly less time per epoch and converge faster than the JKOnet family. Compared to JKOnet, our model also requires a simpler architecture: we drop the additional ICNN used in the inner iteration and the related training details. Notice that simply replacing the ICNN in JKOnet with a vanilla MLP deprives the method of the theoretical connections with optimal transport, which, in our experiments, appears to be associated with stability (NaN in the topmost plot).
Running the experiment
We provide the following scripts to run all the experiments:
bash -x scripts/exp1_jkonet_star_potential.sh
bash -x scripts/exp1_jkonet_star.sh
bash -x scripts/exp1_jkonet_star_linear_potential.sh
bash -x scripts/exp1_jkonet_star_linear.sh
bash -x scripts/exp1_jkonet.sh
bash -x scripts/exp1_jkonet_vanilla.sh
Post-processing
To retrieve the results from wandb and write them into a file for later visualization, we provide the following option:
python scripts/exp1_plot.py
Experiment 4.2: Scaling laws
Experimental Setting
We evaluate the capability of JKOnet*V to recover the correct potential energy given \(N \in \{1000, 2500, 5000, 7500, 10000\}\) particles across dimensions \(d \in \{10, 20, 30, 40, 50\}\).
Results
Below we display the EMD error obtained for every configuration. The stable color along the rows suggests an almost constant error (the EMD is related to the Euclidean norm and is thus expected to grow linearly with the dimension \(d\); here, the growth is strongly sublinear) up to the point where the number of particles is no longer informative enough (along the columns, the error decreases again as \(N\) grows). The time complexity of computing the optimal transport plans grows linearly with the dimension \(d\) and is negligible compared to the solution of the linear program, which depends only on the number of particles; see the paper for more details. We thus conclude that JKOnet* is well suited for high-dimensional tasks.
Running the experiment
We provide the following script to run all the experiments:
bash -x scripts/exp2.sh
Post-processing
To retrieve the results from wandb and write them into a file for later visualization, we provide the following option:
python scripts/exp2_plot.py
Experiment 4.3: General energy functionals
Experimental Setting
We showcase the capability of the JKOnet* models to recover the potential, interaction, and internal energies, selected as combinations of the functions in utils.functions and noise levels \(\beta \in \{0.0, 0.1, 0.2\}\). To our knowledge, this is the first model to recover all three energy terms.
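For context on how a noise level \(\beta\) can enter such population dynamics, a common choice (an illustrative assumption here, not necessarily the exact scheme used to generate the data; see the paper) is an Euler-Maruyama discretization of a Langevin diffusion:

```python
import numpy as np

def langevin_step(x, grad_V, tau=0.01, beta=0.1, rng=None):
    """One Euler-Maruyama step of dX = -grad V(X) dt + sqrt(2 * beta) dB."""
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.normal(size=x.shape)
    return x - tau * grad_V(x) + np.sqrt(2.0 * beta * tau) * noise
```

With \(\beta = 0\) the step reduces to the deterministic drift of Experiment 4.1; larger \(\beta\) spreads the population and contributes the sampling error discussed in the results below.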
Results
Below we collect the numerical results of the experiment. Compared to the setting in Experiment 4.1, there are two additional sources of inaccuracy: (i) the noise, which introduces an inevitable sampling error, and (ii) the estimation of the densities (see the paper for more details). Nonetheless, the low EMD errors demonstrate the capability of JKOnet* to recover the energy components that best explain the observed populations.
Running the experiment
We provide the following script to run all the experiments:
bash -x scripts/exp3.sh
Post-processing
To retrieve the results from wandb and write them into a file for later visualization, we provide the following option:
python scripts/exp3_plot.py
Note
The _plot.py scripts generate the data we rendered in the paper, but you're on your own when it comes to generating the plots (we like TikZ). If you want to implement the plotting in Python and contribute to the repo, we would be very happy to accept a PR!
Note
To reproduce the results faster, you can reduce the number of epochs to 100 and change the evaluation frequency to every 1000 epochs. The results will not change substantially.
Note
Different machines may yield slightly different results, but they should not change substantially. If they do in your setup, please let us know.