Comparing MDOT-TNT with POT Solvers
This tutorial demonstrates MDOT-TNT on a large random OT problem (n=14,000) and compares its speed and accuracy against the POT library’s exact solver (EMD), Sinkhorn, and Greenkhorn.
First, install the POT package (https://pythonot.github.io/) for comparison.
[1]:
!pip install POT
Requirement already satisfied: POT in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (0.9.4)
Requirement already satisfied: numpy>=1.16 in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (from POT) (1.26.4)
Requirement already satisfied: scipy>=1.6 in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (from POT) (1.15.1)
Import packages.
[2]:
import gc
import time
import ot
import torch as th
from mdot_tnt import solve_OT
from mdot_tnt.rounding import round_altschuler
device = "cuda:0"
Add a function for sampling random OT problems.
[3]:
def sample_random_problem(N, M, dim=100):
    # Sample marginals r (length N) and c (length M) from a flat Dirichlet distribution.
    r = th.distributions.Dirichlet(th.ones(N)).sample()
    c = th.distributions.Dirichlet(th.ones(M)).sample()
    # Sample N points x and M points y from a standard multivariate normal in `dim` dimensions (100 by default).
    x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
    y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
    # Compute the squared Euclidean cost matrix C[i, j] = ||x_i - y_j||_2^2,
    # normalized so that max(C) = 1.
    C = th.cdist(x, y, p=2) ** 2
    C /= C.max()
    # Renormalize the marginals to guard against floating-point drift.
    # (Conversion to double precision is done by the caller after sampling.)
    r /= r.sum()
    c /= c.sum()
    return r, c, C
Sample a large problem (n = 14,000) and move it to the GPU in double precision.
[6]:
n = 14000
r, c, C = sample_random_problem(n, n)
r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)
Let’s time the POT library’s highly efficient exact solver, ot.emd2, which runs on the CPU.
[7]:
r_, c_, C_ = r.cpu().numpy(), c.cpu().numpy(), C.cpu().numpy()
time_start = time.time()
cost_emd = ot.emd2(r_, c_, C_, numItermax=int(1e10))
elapsed = time.time() - time_start
print(f"OT Cost: {cost_emd:.10f}, Time: {elapsed:.3f}")
OT Cost: 0.3182945673, Time: 97.668
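For reference, ot.emd2 returns the optimal value of a linear program. The program minimizes <P, C> over nonnegative matrices P whose row sums equal r and whose column sums equal c. POT solves it with a dedicated network simplex solver. As a minimal sketch, the code below instead hands the same LP to SciPy's general-purpose `linprog` (HiGHS backend). That is only practical for toy sizes because the LP has n * m variables. `exact_ot_lp` is a hypothetical helper, not part of POT or mdot_tnt.

```python
import numpy as np
from scipy.optimize import linprog


def exact_ot_lp(r, c, C):
    """Exact OT: minimize <P, C> over P >= 0 with P @ 1 = r and P.T @ 1 = c.

    Dense LP with n * m variables, so only suitable for tiny problems.
    """
    r, c, C = (np.asarray(a, dtype=float) for a in (r, c, C))
    n, m = C.shape
    # Row constraints: sum_j P[i, j] = r[i]  (P is flattened in row-major order).
    A_rows = np.kron(np.eye(n), np.ones((1, m)))
    # Column constraints: sum_i P[i, j] = c[j].
    A_cols = np.kron(np.ones((1, n)), np.eye(m))
    res = linprog(
        C.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([r, c]),
        bounds=(0, None),
        method="highs",
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.fun, res.x.reshape(n, m)


# Toy example: 0.25 of row 1's mass must move to column 0 at unit cost.
r = [0.25, 0.75]
c = [0.5, 0.5]
C = [[0.0, 1.0], [1.0, 0.0]]
cost, plan = exact_ot_lp(r, c, C)
print(f"OT cost: {cost:.4f}")  # OT cost: 0.2500
print(plan)
```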
Now use mdot_tnt to tackle the problem on a GPU (NVIDIA RTX 2080 Ti in this case).
[8]:
time_start = time.time()
cost = solve_OT(
r, c, C, gamma_f=1000
) # gamma_f is the inverse of the final regularization weight (1e-3 here)
elapsed = time.time() - time_start
print(f"MDOT-TNT error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}")
gc.collect()
th.cuda.empty_cache()
MDOT-TNT error: 8.803e-05, Time: 4.435
About 4-5 decimal digits of precision (absolute error 8.8e-05) with a 22x speedup (97.668 s vs. 4.435 s). The speedup can be even larger on higher-end GPUs. Let’s also check the speedup using FP32 precision.
[9]:
time_start = time.time()
cost = solve_OT(
r.float(), c.float(), C.float(), gamma_f=1000
) # gamma_f is the inverse of the final regularization weight (1e-3 here)
elapsed = time.time() - time_start
print(f"MDOT-TNT error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}")
gc.collect()
th.cuda.empty_cache()
MDOT-TNT error: 8.821e-05, Time: 1.705
A 57x speedup (97.668 s vs. 1.705 s) on this random problem, with essentially the same error. Not bad!
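Why does FP32 lose so little accuracy here? The quick check below uses the recorded numbers only. It is plain arithmetic, not a rigorous error analysis. Switching from FP64 to FP32 moved the error by only about 1.8e-7, roughly 0.2% of the error itself. That suggests the ~8.8e-5 error comes mostly from the solver's approximation (regularization and termination), not from floating-point rounding.

```python
import numpy as np

# Machine epsilon (relative rounding unit) of the two formats.
eps64 = float(np.finfo(np.float64).eps)  # about 2.2e-16
eps32 = float(np.finfo(np.float32).eps)  # about 1.2e-7

# Values recorded in the FP64 and FP32 MDOT-TNT cells above.
opt_cost = 0.3182945673
err_fp64 = 8.803e-05
err_fp32 = 8.821e-05

# How much the switch to FP32 moved the error, relative to the error itself.
shift = abs(err_fp32 - err_fp64)  # about 1.8e-7
relative_shift = shift / err_fp64  # about 0.2%
# Rough scale of FP32 rounding error at the magnitude of the optimal cost.
fp32_step = eps32 * opt_cost  # about 3.8e-8

print(f"machine epsilon: FP32 {eps32:.1e}, FP64 {eps64:.1e}")
print(f"error shift FP64 -> FP32: {shift:.1e} ({relative_shift:.1%} of the error)")
print(f"FP32 rounding scale at the cost magnitude: {fp32_step:.1e}")
```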
If either marginal is known to have many tiny entries (i.e., it is effectively a sparse vector), we can further accelerate computation by setting drop_tiny=True, which drops those particles from the problem. This feature was not used in the paper, to keep the benchmarks fair, but it can be useful in practice.
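Conceptually, dropping tiny entries shrinks the problem to the rows and columns that carry non-negligible mass. You solve the smaller problem, then embed its plan back with zeros elsewhere. The NumPy sketch below illustrates only that idea. `solve_with_tiny_dropped`, its `threshold` value, and the pluggable `solver` argument are hypothetical. They are not mdot_tnt's implementation or its rule for deciding what counts as tiny.

```python
import numpy as np


def solve_with_tiny_dropped(r, c, C, solver, threshold=1e-15):
    """Solve OT only on marginal entries above `threshold`, then embed the plan.

    Hypothetical helper for illustration: `solver(r, c, C)` is any function that
    returns a transport plan, and the cutoff is arbitrary (not mdot_tnt's rule).
    """
    r, c, C = (np.asarray(a, dtype=float) for a in (r, c, C))
    keep_r, keep_c = r > threshold, c > threshold
    # Restrict the marginals to the kept particles and renormalize each to sum to 1.
    r_sub = r[keep_r] / r[keep_r].sum()
    c_sub = c[keep_c] / c[keep_c].sum()
    # Keep only the rows and columns of C that belong to kept particles.
    C_sub = C[np.ix_(keep_r, keep_c)]
    plan_sub = solver(r_sub, c_sub, C_sub)
    # Scatter the sub-plan back; dropped rows and columns get zero mass.
    plan = np.zeros_like(C)
    plan[np.ix_(keep_r, keep_c)] = plan_sub
    return plan


rng = np.random.default_rng(0)
n = 6
r, c = rng.dirichlet(np.ones(n)), rng.dirichlet(np.ones(n))
r[:3] = 1e-20  # make half of r tiny, as in the cell below
c[::2] = 1e-20  # and half of c
r, c = r / r.sum(), c / c.sum()
C = rng.random((n, n))
# Stand-in solver: the independent coupling outer(r, c) is feasible but not optimal;
# in practice an actual OT solver would go here.
plan = solve_with_tiny_dropped(r, c, C, solver=lambda a, b, K: np.outer(a, b))
print(f"kept {plan.any(axis=1).sum()} rows and {plan.any(axis=0).sum()} columns")
# kept 3 rows and 3 columns
```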
[10]:
# Set a random half of the entries of r and c to 1e-20, and renormalize.
r2 = r.clone()
c2 = c.clone()
r2[th.randperm(n)[: n // 2]] = 1e-20
c2[th.randperm(n)[: n // 2]] = 1e-20
r2 /= r2.sum()
c2 /= c2.sum()
[11]:
time_start = time.time()
cost_emd2 = ot.emd2(r2.cpu().numpy(), c2.cpu().numpy(), C.cpu().numpy(), numItermax=int(1e10))
elapsed = time.time() - time_start
print(f"OT Cost: {cost_emd:.10f}, Time: {elapsed:.3f}")
OT Cost: 0.3182945673, Time: 82.562
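The exact solver takes about as long as before (82.562 s vs. 97.668 s). Note that the cell above prints cost_emd (the cost of the original, dense problem) rather than cost_emd2, so only its timing is new. The MDOT-TNT errors below are measured against cost_emd2. Let’s rerun MDOT-TNT with drop_tiny=True.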
[12]:
time_start = time.time()
cost = solve_OT(
r2, c2, C, gamma_f=1000, drop_tiny=True
) # gamma_f is the inverse of the final regularization weight (1e-3 here)
elapsed = time.time() - time_start
print(f"MDOT-TNT error: {cost - cost_emd2:.3e}, Time: {elapsed:.3f}")
gc.collect()
th.cuda.empty_cache()
Dropped 7028 entries from r and 7032 entries from c.
MDOT-TNT error: 8.172e-05, Time: 1.155
The same level of precision as before, but this time with a ~70x speedup (82.562 s vs. 1.155 s). Now let’s do the same in FP32 precision.
[13]:
time_start = time.time()
cost = solve_OT(
r2.float(), c2.float(), C.float(), gamma_f=1000, drop_tiny=True
) # gamma_f is the inverse of the final regularization weight (1e-3 here)
elapsed = time.time() - time_start
print(f"MDOT-TNT error: {cost - cost_emd2:.3e}, Time: {elapsed:.3f}")
gc.collect()
th.cuda.empty_cache()
Dropped 7028 entries from r and 7032 entries from c.
MDOT-TNT error: 8.187e-05, Time: 0.535
A 154x speedup (82.562 s vs. 0.535 s). Let’s go back to the original problem (dense marginals) and see how Sinkhorn fares, starting with strong regularization (reg = 1/100) and then weakening it.
[14]:
gc.collect()
th.cuda.empty_cache()
time_start = time.time()
plan = ot.sinkhorn(r, c, C, reg=1 / 100, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print(f"Sinkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}")
del plan
Sinkhorn error: 1.458e-02, Time: 0.511
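Each solver cell in this section passes the entropic plan through round_altschuler before evaluating its cost. Sinkhorn stops with marginals that only approximately match r and c. Rounding makes the plan exactly feasible, so (plan * C).sum() is the cost of a valid coupling and can never fall below the optimal cost. Judging by its name, the function implements the rounding procedure of Altschuler et al. (2017). The NumPy sketch below is our own version of that procedure, not code taken from mdot_tnt. `round_to_feasible` is a hypothetical name.

```python
import numpy as np


def round_to_feasible(F, r, c):
    """Round a nonnegative matrix F to a coupling with row sums r and column sums c.

    Sketch of the rounding procedure of Altschuler et al. (2017); assumes
    r.sum() == c.sum() and F >= 0.
    """
    F = np.array(F, dtype=float)
    r, c = np.asarray(r, dtype=float), np.asarray(c, dtype=float)
    # 1) Scale down every row whose mass exceeds its target (never scale up).
    row_sums = F.sum(axis=1)
    over = row_sums > r
    F[over] *= (r[over] / row_sums[over])[:, None]
    # 2) Same for columns; this only shrinks entries, so row sums stay <= r.
    col_sums = F.sum(axis=0)
    over = col_sums > c
    F[:, over] *= (c[over] / col_sums[over])[None, :]
    # 3) Both deficits are now >= 0 and have equal totals; adding them back as a
    #    rank-one correction makes both marginals exact.
    err_r = r - F.sum(axis=1)
    err_c = c - F.sum(axis=0)
    total = err_r.sum()
    if total > 0:
        F += np.outer(err_r, err_c) / total
    return F


rng = np.random.default_rng(0)
r = rng.dirichlet(np.ones(4))
c = rng.dirichlet(np.ones(5))
F = rng.random((4, 5)) / 10  # an arbitrary, infeasible "approximate plan"
P = round_to_feasible(F, r, c)
print(np.abs(P.sum(axis=1) - r).max(), np.abs(P.sum(axis=0) - c).max())
```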
Remember that the optimal cost is about 0.318, so the relative error of this Sinkhorn run (reg = 1/100) is about 0.0146 / 0.318 ≈ 4.6%, which is hardly negligible. Let’s run Sinkhorn at the same regularization level as MDOT-TNT (reg = 1/1000, i.e., gamma_f = 1000).
[15]:
time_start = time.time()
plan = ot.sinkhorn(r, c, C, reg=1 / 1000, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print(f"Sinkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}")
del plan
Sinkhorn error: 7.315e-05, Time: 67.102
MDOT-TNT is about 15x faster at a similar error. It took 4.435 s under the same setup of dense marginals and FP64 precision, with an error of 8.8e-05 vs. Sinkhorn's 7.3e-05. As we show in the paper, the gap grows as the regularization gets weaker.
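A standard bound for entropic OT makes the need for weak regularization quantitative. The exact entropic optimum with weight reg overpays the true OT cost by at most reg * log(n * m). The bound uses POT's convention that reg multiplies the negative-entropy term. The sketch below evaluates this worst-case bound for n = m = 14,000 and compares it with the Sinkhorn errors recorded above. Those errors also include rounding and early stopping, so this is a sanity check rather than a prediction. `entropic_gap_bound` is a hypothetical helper name.

```python
import math


def entropic_gap_bound(reg, n, m):
    """Upper bound on <P_reg, C> - OT(r, c) for the exact entropic plan P_reg.

    P_reg minimizes <P, C> - reg * H(P) with H(P) = -sum(P * log P). Comparing
    its objective with that of an optimal OT plan P* gives
        <P_reg, C> - <P*, C> <= reg * (H(P_reg) - H(P*)) <= reg * log(n * m),
    because 0 <= H(P) <= log(n * m) for any coupling.
    """
    return reg * math.log(n * m)


n = 14000
# (regularization weight, Sinkhorn error observed in the cells above)
for reg, observed in [(1 / 100, 1.458e-02), (1 / 1000, 7.315e-05)]:
    bound = entropic_gap_bound(reg, n, n)
    print(f"reg = {reg:g}: bound {bound:.3e}, observed {observed:.3e}")
# reg = 0.01: bound 1.909e-01, observed 1.458e-02
# reg = 0.001: bound 1.909e-02, observed 7.315e-05
```

The bound is loose, but it shrinks only linearly in reg. That is why accurate solutions call for small regularization weights, which is exactly where plain Sinkhorn slows down (67.102 s at reg = 1/1000 vs. 0.511 s at reg = 1/100).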
Let’s also give Greenkhorn by Altschuler et al. (2017) a try.
[17]:
gc.collect()
th.cuda.empty_cache()
time_start = time.time()
plan = ot.bregman.greenkhorn(r, c, C, reg=1 / 1000, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print(f"Greenkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}")
del plan
Greenkhorn error: 7.461e-05, Time: 2929.723
For n = 14,000, Greenkhorn suffers from low GPU utilization. It may perform fewer row and column updates in total than Sinkhorn, but it is substantially slower in practice (2929.723 s vs. 67.102 s for Sinkhorn). The reason is that it updates one row or column at a time, which limits parallelism.
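The small NumPy sketch below makes the parallelism point concrete. It is our own toy code, not POT's implementation, and the function names are hypothetical. It contrasts the two update rules on the scaling vectors u and v of the plan diag(u) K diag(v), with K = exp(-C / reg). A Sinkhorn iteration rescales every row and then every column through two matrix-vector products, which parallelize well on a GPU. A Greenkhorn step follows the greedy rule of Altschuler et al. (2017). It picks the single most violated row or column and rescales only that one, so the work arrives in many small sequential pieces.

```python
import numpy as np


def sinkhorn_sweep(K, r, c, u, v):
    """One Sinkhorn iteration: rescale all rows, then all columns (two mat-vecs)."""
    u = r / (K @ v)
    v = c / (K.T @ u)
    return u, v


def greenkhorn_step(K, r, c, u, v):
    """One Greenkhorn step: rescale only the most violated row or column."""
    # Current plan diag(u) K diag(v). Recomputed in full here for clarity only;
    # a practical implementation updates the row/column sums incrementally.
    P = u[:, None] * K * v[None, :]
    row_sums, col_sums = P.sum(axis=1), P.sum(axis=0)
    # Violation rho(a, b) = b - a + a * log(a / b) of each target a vs. current b.
    rho_r = row_sums - r + r * np.log(r / row_sums)
    rho_c = col_sums - c + c * np.log(c / col_sums)
    i, j = int(np.argmax(rho_r)), int(np.argmax(rho_c))
    u, v = u.copy(), v.copy()
    if rho_r[i] >= rho_c[j]:
        u[i] = r[i] / (K[i] @ v)  # make row i's sum exactly r[i]
    else:
        v[j] = c[j] / (K[:, j] @ u)  # make column j's sum exactly c[j]
    return u, v


def marginal_error(K, r, c, u, v):
    # L1 distance of the plan's marginals from the targets (r, c).
    P = u[:, None] * K * v[None, :]
    return np.abs(P.sum(axis=1) - r).sum() + np.abs(P.sum(axis=0) - c).sum()


# Toy problem with a mild regularization weight so that both methods converge fast.
rng = np.random.default_rng(0)
n, reg = 5, 1.0
C = rng.random((n, n))
K = np.exp(-C / reg)
r = np.full(n, 1.0 / n)
c = rng.dirichlet(np.ones(n))

u_s, v_s = np.ones(n), np.ones(n)
for _ in range(100):  # 100 sweeps, each touching all 2n rows and columns
    u_s, v_s = sinkhorn_sweep(K, r, c, u_s, v_s)

u_g, v_g = np.ones(n), np.ones(n)
for _ in range(100 * 2 * n):  # same number of row/column updates, one at a time
    u_g, v_g = greenkhorn_step(K, r, c, u_g, v_g)

print(f"Sinkhorn marginal error:   {marginal_error(K, r, c, u_s, v_s):.2e}")
print(f"Greenkhorn marginal error: {marginal_error(K, r, c, u_g, v_g):.2e}")
```

Both reach small marginal errors on this toy problem. The difference is in how the work is shaped: 100 large vectorized updates for Sinkhorn vs. 1,000 tiny sequential ones for Greenkhorn.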