Quick Start
Single Problem
import torch
import mdot_tnt
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()
# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
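A returned plan can be sanity-checked against its defining properties: its rows should sum to r, its columns to c, and its cost is the elementwise product with C summed over all entries. The sketch below uses only torch. `check_plan` is a hypothetical helper, not part of mdot_tnt, and the independent coupling `torch.outer(r, c)` stands in for a solver output so that the snippet runs on its own. Because the solver is iterative, the marginals of a real plan may match only up to a small tolerance.

```python
import torch


def check_plan(plan, r, c, C):
    """Hypothetical helper (not part of mdot_tnt): report how well `plan`
    matches the marginals r and c, and the transport cost it implies."""
    row_err = (plan.sum(dim=1) - r).abs().sum().item()  # L1 error on row marginal
    col_err = (plan.sum(dim=0) - c).abs().sum().item()  # L1 error on column marginal
    cost = (plan * C).sum().item()                      # inner product <plan, C>
    return row_err, col_err, cost


torch.manual_seed(0)
n, m = 8, 6
r = torch.rand(n, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, dtype=torch.float64)
c = c / c.sum()
C = torch.rand(n, m, dtype=torch.float64)

# Stand-in for a solver output: the independent coupling r c^T is always a
# feasible (though generally not optimal) transport plan.
plan = torch.outer(r, c)

row_err, col_err, cost = check_plan(plan, r, c, C)
print(f"row error {row_err:.2e}, column error {col_err:.2e}, cost {cost:.4f}")
```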
Batched Solving
When solving multiple OT problems, use the batched solver, which is significantly faster than solving them one at a time:
import torch
import mdot_tnt
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512
# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024) # Returns (batch_size,) tensor
The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
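For comparison, the sequential baseline that the batched call replaces can be sketched as a Python loop over the batch. `solve_sequential` and its `solve_fn` argument are hypothetical names introduced here for illustration. In practice `solve_fn` would be something like `mdot_tnt.solve_OT` with `gamma_f` bound, while the demo below substitutes a trivial stand-in so that it runs without the library.

```python
import torch


def solve_sequential(solve_fn, r, c, C):
    """Hypothetical baseline: solve each problem in the batch separately.

    r: (batch, n), c: (batch, m), C: (n, m) shared or (batch, n, m) per problem.
    solve_fn(r_i, c_i, C_i) must return a scalar tensor (the cost).
    """
    costs = []
    for i in range(r.shape[0]):
        C_i = C if C.dim() == 2 else C[i]  # shared vs. per-problem cost matrix
        costs.append(solve_fn(r[i], c[i], C_i))
    return torch.stack(costs)  # shape (batch,)


def dummy_solve(r_i, c_i, C_i):
    """Stand-in (NOT an OT solver): cost of the independent coupling r c^T."""
    return r_i @ C_i @ c_i


batch_size, n, m = 4, 16, 12
r = torch.rand(batch_size, n, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
C = torch.rand(n, m, dtype=torch.float64)

costs = solve_sequential(dummy_solve, r, c, C)
print(costs.shape)
```

The same loop handles a per-problem cost tensor of shape (batch_size, n, m) by indexing `C[i]`.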
Low-Memory / Point Cloud Solving
For large problems where allocating a full n × m cost matrix is prohibitive,
use mdot_tnt.lowmem.solve_OT_lowmem(). It processes columns in blocks of
size block_size (k), keeping only O(nk) entries in working memory at a time.
Point cloud mode (true O(nk + (n+m)d) memory — C is never materialised):
import torch
from mdot_tnt.lowmem import solve_OT_lowmem
device = "cuda" if torch.cuda.is_available() else "cpu"
n, m, d = 10000, 10000, 64
X = torch.rand(n, d, device=device, dtype=torch.float64)
Y = torch.rand(m, d, device=device, dtype=torch.float64)
r = torch.ones(n, device=device, dtype=torch.float64) / n
c = torch.ones(m, device=device, dtype=torch.float64) / m
# block_size controls the memory / speed trade-off
cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, block_size=512)
Dense matrix mode (full C provided, but only k columns loaded at once):
C = torch.rand(n, m, device=device, dtype=torch.float64)
cost = solve_OT_lowmem(r, c, C=C, gamma_f=1024, block_size=512)
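The idea of block-wise processing can be illustrated with a row-wise log-sum-exp over a cost matrix that is never stored in full, a reduction typical of Sinkhorn-type solvers. This is a minimal sketch and not the library's implementation: `blockwise_row_lse` is a hypothetical name, and the squared Euclidean cost is an assumption made here (the cost that `solve_OT_lowmem` derives from `X` and `Y` is not stated above).

```python
import torch


def blockwise_row_lse(X, Y, gamma, block_size):
    """Hypothetical sketch: compute logsumexp_j(-gamma * C[i, j]) for every row i,
    where C[i, j] = ||X[i] - Y[j]||^2 (assumed cost), using column blocks so that
    at most an (n, block_size) slice of C exists at any time."""
    n, m = X.shape[0], Y.shape[0]
    out = torch.full((n,), float("-inf"), dtype=X.dtype, device=X.device)
    for start in range(0, m, block_size):
        Y_blk = Y[start:start + block_size]           # (k, d) block of target points
        C_blk = torch.cdist(X, Y_blk) ** 2            # (n, k) slice of the cost
        blk = torch.logsumexp(-gamma * C_blk, dim=1)  # reduce within the block
        out = torch.logaddexp(out, blk)               # merge with the running result
    return out


torch.manual_seed(0)
n, m, d = 200, 150, 8
X = torch.rand(n, d, dtype=torch.float64)
Y = torch.rand(m, d, dtype=torch.float64)

lse_blocked = blockwise_row_lse(X, Y, gamma=16.0, block_size=32)

# Reference: the same reduction with the full (n, m) cost matrix materialised.
lse_dense = torch.logsumexp(-16.0 * torch.cdist(X, Y) ** 2, dim=1)
print(torch.allclose(lse_blocked, lse_dense))
```

Only the running result of length n and one (n, k) block are alive inside the loop, which is where the O(nk) working-memory figure comes from.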
Memory / speed trade-off via block_size:
block_size = m (default): fastest, same memory as solve_OT()
block_size = int(m**0.5): good balance
block_size = 1: minimum memory (slowest)
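To make the trade-off concrete, the size of the cost block alone can be estimated as n × block_size × 8 bytes in float64. The helper below is a hypothetical back-of-the-envelope calculation. It ignores the solver's other buffers (marginals, dual variables, point clouds), so actual usage will be higher.

```python
import math


def cost_block_bytes(n, block_size, bytes_per_element=8):
    """Bytes needed to hold one (n, block_size) block of the cost matrix.
    Hypothetical estimate; 8 bytes per element corresponds to float64."""
    return n * block_size * bytes_per_element


n = m = 10_000
for block_size in (m, math.isqrt(m), 1):
    mb = cost_block_bytes(n, block_size) / 1e6
    print(f"block_size={block_size:>6}: {mb:10.2f} MB")
# block_size = 10000 -> 800 MB (the full n x m matrix)
# block_size = 100   ->   8 MB
# block_size = 1     ->   0.08 MB
```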
Performance Tips
Use float64 for gamma_f > 1024 (automatic conversion with warning).
Normalize cost matrices to [0, 1] for numerical stability.
Use batched solver when solving multiple problems with shared structure.
Increase gamma_f for higher precision (error scales as \(O(\log n / \gamma)\) in the worst case, but can be much better).
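Two of these tips can be made concrete. The sketch below rescales a cost matrix to [0, 1], maps a cost computed on the rescaled matrix back to the original units, and evaluates the \(\log n / \gamma\) scale for n = 512. `normalize_cost` and `denormalize_cost` are hypothetical helpers, and min-max scaling is one possible choice (dividing by the maximum is another). The back-mapping relies on a transport plan summing to 1. It is also assumed here that \(\gamma\) corresponds to gamma_f; the constant hidden in the \(O(\cdot)\) is unknown, so the last number indicates a scale, not an error bound.

```python
import math
import torch


def normalize_cost(C):
    """Hypothetical helper: min-max scale C to [0, 1].
    Returns (C_norm, shift, scale); assumes C is not constant."""
    shift = C.min()
    scale = C.max() - shift
    return (C - shift) / scale, shift, scale


def denormalize_cost(cost_norm, shift, scale):
    """Map a transport cost on C_norm back to the units of C.
    Valid because a plan P sums to 1: <P, C> = scale * <P, C_norm> + shift."""
    return scale * cost_norm + shift


torch.manual_seed(0)
n, m = 512, 512
C = 100.0 * torch.rand(n, m, dtype=torch.float64) + 5.0  # costs far outside [0, 1]
C_norm, shift, scale = normalize_cost(C)

# Check the back-mapping with a feasible stand-in plan (uniform coupling).
P = torch.full((n, m), 1.0 / (n * m), dtype=torch.float64)
cost_norm = (P * C_norm).sum()
print(torch.allclose(denormalize_cost(cost_norm, shift, scale), (P * C).sum()))

# Worst-case error scale log(n) / gamma, up to an unknown constant factor.
gamma = 1024
error_scale = math.log(n) / gamma
print(f"{error_scale:.4f}")
```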