API Reference
Main Functions
- mdot_tnt.solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)
Solve the entropic-regularized optimal transport problem.
Inputs r, c, C are required to be torch tensors.
- Parameters:
r – Row marginal of shape (n,). Must sum to 1.
c – Column marginal of shape (m,). Must sum to 1.
C – Cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
gamma_f – Temperature (inverse of the regularization weight). For many problems, stable up to 2^18. Higher values return more accurate solutions but take longer to converge. Use double precision if gamma_f is large.
drop_tiny – If either marginal is known to be sparse, set this to True to drop tiny entries for a speedup. If return_plan is True, the returned plan will be in the original dimensions.
return_plan – If True, return the optimal transport plan rather than the cost.
round – If True, use the rounding algorithm of Altschuler et al. (2017) to (a) return a feasible plan if return_plan is True, or (b) return the cost of the rounded plan if return_plan is False.
log – If True, additionally return a dictionary containing logs of the optimization process.
devices – List of CUDA devices (int, str, or torch.device) to distribute column blocks of C across. Enables multi-GPU solving of a single OT problem by splitting the cost matrix along the column dimension and running blocked operations in parallel. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. Ignored when only one device is given.
num_gpus – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.
- Returns:
Transport cost (scalar tensor) if return_plan is False, or the transport plan of shape (n, m) if return_plan is True. If log is True, returns a tuple of (result, logs_dict).
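To make the role of the temperature gamma_f concrete, here is a minimal sketch in plain Python. It uses ordinary log-domain Sinkhorn iterations, not the MDOT-TNT algorithm that solve_OT runs, and the names `logsumexp`, `sinkhorn_cost` and `iters` are hypothetical helpers introduced only for this illustration. It shows that a higher temperature yields a transport cost closer to the unregularized optimum (0 for this toy problem).

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a list of floats."""
    top = max(xs)
    return top + math.log(sum(math.exp(x - top) for x in xs))

def sinkhorn_cost(r, c, C, gamma, iters=200):
    """Cost <P, C> of the entropic plan P_ij = exp(u_i + v_j - gamma * C_ij).

    Hypothetical helper (plain Sinkhorn updates); needs strictly positive marginals.
    """
    n, m = len(r), len(c)
    u, v = [0.0] * n, [0.0] * m
    for _ in range(iters):
        # Alternately match the row marginals, then the column marginals.
        u = [math.log(r[i]) - logsumexp([v[j] - gamma * C[i][j] for j in range(m)])
             for i in range(n)]
        v = [math.log(c[j]) - logsumexp([u[i] - gamma * C[i][j] for i in range(n)])
             for j in range(m)]
    return sum(math.exp(u[i] + v[j] - gamma * C[i][j]) * C[i][j]
               for i in range(n) for j in range(m))

r = [0.5, 0.5]
c = [0.5, 0.5]
C = [[0.0, 1.0], [1.0, 0.0]]  # the unregularized optimal cost is 0

low = sinkhorn_cost(r, c, C, gamma=2.0)    # strong regularization
high = sinkhorn_cost(r, c, C, gamma=16.0)  # weak regularization
print(low, high)
```

In this sketch the cost at gamma=16 is much closer to 0 than the cost at gamma=2, which mirrors the accuracy/runtime trade-off described for gamma_f.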
- mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)
Solve multiple entropic-regularized optimal transport problems in a single batched call.
This function provides significant speedup (5-10x) over solving problems sequentially by amortizing GPU synchronization overhead across all problems in the batch.
- Parameters:
r (Tensor) – (batch, n) row marginals. Each row must sum to 1.
c (Tensor) – (batch, m) column marginals. Each row must sum to 1.
C (Tensor) – Cost matrix. Either (n, m) for shared cost across all problems, or (batch, n, m) for per-problem costs. Recommended to scale to [0, 1].
gamma_f (float) – Temperature (inverse of regularization weight). Higher values give more accurate solutions but take longer. Stable up to ~2^18 with float64.
drop_tiny (bool) – Not supported in batched solver. Raises NotImplementedError if True.
return_plan (bool) – If True, return transport plans instead of costs.
round (bool) – If True, apply Altschuler rounding for feasible solutions.
log (bool) – If True, also return optimization logs.
devices (Optional[Any]) – List of CUDA devices (int, str, or torch.device) to distribute the batch across. The batch is split evenly; each GPU solves its share independently. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. Ignored when only one device is given.
num_gpus (Optional[int]) – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.
- Returns:
If return_plan is False: (batch,) tensor of transport costs. If return_plan is True: (batch, n, m) tensor of transport plans. If log is True: tuple of (result, logs_dict).
Example
>>> import torch
>>> from mdot_tnt import solve_OT_batched
>>> # Solve 32 OT problems of size 512×512
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> r = r / r.sum(-1, keepdim=True)
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> c = c / c.sum(-1, keepdim=True)
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
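The devices and num_gpus options, and the even split of the batch across GPUs, can be pictured with the following sketch in plain Python. Both helpers (`resolve_devices_sketch`, `split_batch_sketch`) are hypothetical and are not the library's internals; in particular, how the library distributes a remainder when the batch size is not divisible by the number of devices is an assumption here.

```python
def resolve_devices_sketch(devices=None, num_gpus=None):
    """Hypothetical helper: normalise the two mutually exclusive options
    to a list of device strings such as 'cuda:0'."""
    if devices is not None and num_gpus is not None:
        raise ValueError("devices and num_gpus are mutually exclusive")
    if num_gpus is not None:
        # num_gpus counts devices starting from index 0.
        return [f"cuda:{i}" for i in range(num_gpus)]
    if devices is None:
        return []
    return [f"cuda:{d}" if isinstance(d, int) else str(d) for d in devices]

def split_batch_sketch(batch, num_devices):
    """Split range(batch) into contiguous, near-equal (start, end) chunks.

    Assumption: any remainder is spread over the first chunks."""
    base, extra = divmod(batch, num_devices)
    chunks, start = [], 0
    for i in range(num_devices):
        size = base + (1 if i < extra else 0)
        chunks.append((start, start + size))
        start += size
    return chunks

devs = resolve_devices_sketch(devices=[0, "cuda:1"])
print(devs, split_batch_sketch(32, len(devs)))
```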
- mdot_tnt.lowmem.solve_OT_lowmem(r, c, C=None, X=None, Y=None, cost_fn=None, gamma_f=1024.0, block_size=None, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)
Solve entropic-regularized OT with O(nk) working memory.
Accepts either a dense cost matrix or point clouds with a cost function. In point-cloud mode the cost matrix is never materialised, giving true O(nk + (n+m)d) memory.
- Parameters:
r (Tensor) – Row marginal, shape (n,). Must sum to 1.
c (Tensor) – Column marginal, shape (m,). Must sum to 1.
C (Optional[Tensor]) – Dense cost matrix, shape (n, m). Mutually exclusive with X/Y. Recommended to scale entries to [0, 1].
X (Optional[Tensor]) – Source points, shape (n, d). Use with Y and optionally cost_fn.
Y (Optional[Tensor]) – Target points, shape (m, d). Use with X.
cost_fn (Optional[Callable[[Tensor, Tensor], Tensor]]) – (X, Y_block) -> C_block of shape (n, k). Defaults to squared_euclidean. Only used when X/Y are provided.
gamma_f (float) – Temperature (inverse regularization weight).
block_size (Optional[int]) – Columns per block. Controls the memory / compute trade-off:
- None or m: fastest, O(nm) or O(nk) memory
- sqrt(m): good trade-off
- 1: minimum memory (slowest)
drop_tiny (bool) – Drop tiny marginal entries for speedup with sparse marginals.
return_plan (bool) – If True, return the full transport plan. Note: the plan is inherently O(nm); in point-cloud mode it is built block-by-block so that C and P are never both in memory simultaneously.
round (bool) – Use Altschuler rounding for feasible plans / costs.
log (bool) – Additionally return optimization logs.
devices (Optional[Any]) – List of CUDA devices (int, str, or torch.device) to distribute column blocks across. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. When a single device is given the standard single-GPU solver is used.
num_gpus (Optional[int]) – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.
- Returns:
Transport cost (scalar) if return_plan is False, or the transport plan (n, m) if return_plan is True. If log is True, returns (result, logs).
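The point-cloud mode described above can be sketched as follows: a cost function of the form (X, Y_block) -> C_block is wrapped into a (start, end) -> C[:, start:end] callable, and only one column block exists at a time. This is a plain-Python illustration using nested lists instead of tensors; `squared_euclidean_sketch`, `make_cost_block_fn` and `block_ranges` are hypothetical names, not library functions.

```python
def squared_euclidean_sketch(X, Y_block):
    """C[i][j] = ||X[i] - Y_block[j]||^2 for lists of points (pure-Python stand-in)."""
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in Y_block] for x in X]

def make_cost_block_fn(X, Y, cost_fn=squared_euclidean_sketch):
    """Wrap (X, Y_block) -> C_block into (start, end) -> C[:, start:end]."""
    def cost_block_fn(start, end):
        # Only the requested columns are ever computed.
        return cost_fn(X, Y[start:end])
    return cost_block_fn

def block_ranges(m, block_size):
    """Column ranges covering 0..m in steps of block_size (last block may be shorter)."""
    return [(s, min(s + block_size, m)) for s in range(0, m, block_size)]

X = [[0.0, 0.0], [1.0, 0.0]]              # n = 2 source points
Y = [[0.0, 0.0], [0.0, 1.0], [2.0, 0.0]]  # m = 3 target points
cost_block_fn = make_cost_block_fn(X, Y)
for s, e in block_ranges(len(Y), block_size=2):
    print((s, e), cost_block_fn(s, e))
```

With block_size=2 the loop touches an n×2 block and then an n×1 block, so peak memory for costs scales with n·k rather than n·m.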
Core Modules
mdot
Core MDOT solver using truncated Newton projection.
- mdot_tnt.mdot.preprocess_marginals(r, c, C, eps)
Drop the smallest marginal entries whose cumulative sum is below a threshold.
- Parameters:
- Returns:
A tuple containing:
(r_new, r_keep): The new row marginal and indices of kept entries.
(c_new, c_keep): The new column marginal and indices of kept entries.
C: The cost matrix with corresponding rows and columns dropped.
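As an illustration of dropping the smallest entries whose cumulative sum stays below a threshold, here is a minimal sketch for a single marginal in plain Python. The function name `drop_tiny_sketch`, the strict comparison against eps, and the renormalisation of the kept entries are assumptions made for this example; the library's exact behaviour may differ.

```python
def drop_tiny_sketch(p, eps):
    """Drop the smallest entries of p while their cumulative mass stays below eps.

    Returns (p_new, keep): renormalised kept entries and their original indices.
    """
    order = sorted(range(len(p)), key=lambda i: p[i])  # indices, smallest mass first
    dropped, cum = set(), 0.0
    for i in order:
        if cum + p[i] >= eps:
            break  # dropping this entry would reach the threshold
        cum += p[i]
        dropped.add(i)
    keep = [i for i in range(len(p)) if i not in dropped]
    total = sum(p[i] for i in keep)
    return [p[i] / total for i in keep], keep

p = [0.5, 0.001, 0.4, 0.002, 0.097]
p_new, keep = drop_tiny_sketch(p, eps=0.01)
print(keep, p_new)
```

The kept indices are what allow a plan computed on the reduced problem to be scattered back to the original dimensions.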
- mdot_tnt.mdot.smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5)
Smooth the marginals by adding a small amount of uniform mass to each entry.
- Parameters:
- Return type:
- Returns:
A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps from the original marginals.
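One simple way to obtain smoothed marginals with this guarantee is to mix each marginal with the uniform distribution. The sketch below is plain Python and is an assumption about the construction, not a copy of the library's code: mixing with weight w·eps moves a distribution by at most w·eps in total variation, so with w_r + w_c = 1 the total is at most eps. The names `smooth_marginals_sketch` and `tv_distance` are hypothetical.

```python
def smooth_marginals_sketch(r, c, eps, w_r=0.5, w_c=0.5):
    """Mix each marginal with the uniform distribution (illustrative construction)."""
    def mix(p, w):
        n = len(p)
        # Convex combination: still sums to 1 and every entry becomes positive.
        return [(1.0 - w * eps) * x + w * eps / n for x in p]
    return mix(r, w_r), mix(c, w_c)

def tv_distance(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

r = [1.0, 0.0, 0.0]
c = [0.2, 0.8]
eps = 0.1
r_hat, c_hat = smooth_marginals_sketch(r, c, eps)
print(r_hat, c_hat, tv_distance(r, r_hat) + tv_distance(c, c_hat))
```

Strictly positive entries matter because the projectors documented below work with log_r and log_c.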
- mdot_tnt.mdot.adjust_schedule(q, deltas=None)
Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
- mdot_tnt.mdot.mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0)
Solve the entropic-regularized optimal transport problem using the MDOT method.
This implements the MDOT method introduced in the paper: “Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients” by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand. URL: https://arxiv.org/abs/2307.08507
Here, we use the Truncated Newton method for projection.
- Parameters:
r (Tensor) – The first marginal of shape (n,).
c (Tensor) – The second marginal of shape (m,).
C (Tensor) – The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
gamma_f (float) – The final temperature (inverse of the regularization weight).
gamma_i (float) – The initial temperature.
p (float) – The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
q (float) – The temperature annealing (or mirror descent step size) schedule adjustment factor.
- Returns:
A tuple containing:
u: The row dual variables of shape (n,).
v: The column dual variables of shape (m,).
gamma: The final temperature achieved.
k_total: The total number of O(n^2) primitive operations.
logs: Dictionary with optimization statistics.
batched
Batched MDOT-TNT solver for solving multiple optimal transport problems simultaneously.
This module provides batched versions of the MDOT-TNT solver that achieve significant speedups (5-10x) over sequential solving by amortizing GPU synchronization overhead across all problems in a batch.
Key insight: The main solver has many Python while-loops that check convergence, each requiring a GPU→CPU sync. By batching N problems together, we do one sync per iteration for the entire batch instead of N syncs.
- Supports:
Multiple marginal pairs with a shared cost matrix: r shape (batch, n), c shape (batch, m), C shape (n, m)
Multiple OT problems with different costs: r shape (batch, n), c shape (batch, m), C shape (batch, n, m)
- Example usage:
>>> import torch
>>> from mdot_tnt.batched import solve_OT_batched
>>>
>>> # 32 problems, each 512-dimensional
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> r = r / r.sum(dim=-1, keepdim=True)
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> c = c / c.sum(dim=-1, keepdim=True)
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)  # Shared cost
>>>
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
>>> print(costs.shape)  # (32,)
- class mdot_tnt.batched.BatchedTruncatedNewtonProjector(device, dtype, **kwargs)
Batched Truncated Newton projector for the MDOT algorithm.
Projects onto the set of couplings satisfying marginal constraints, processing multiple problems simultaneously for efficiency.
- project(gamma_C, log_r, log_c, eps_d, u, v, active_mask=None)
Project onto the constraint set for all problems in the batch.
- Parameters:
gamma_C (Tensor) – (batch, n, m) or (n, m) cost matrix scaled by gamma.
log_r (Tensor) – (batch, n) log of row marginals.
log_c (Tensor) – (batch, m) log of column marginals.
eps_d (Union[float, Tensor]) – Convergence tolerance, scalar or (batch,) tensor.
u (Tensor) – (batch, n) initial row dual variables.
v (Tensor) – (batch, m) initial column dual variables.
active_mask (Optional[Tensor]) – (batch,) bool tensor, True for problems to process.
- Returns:
u: (batch, n) updated row dual variables.
v: (batch, m) updated column dual variables.
logs: Dictionary with optimization statistics.
success: (batch,) bool tensor indicating convergence per problem.
truncated_newton
Truncated Newton projector for the MDOT algorithm.
- class mdot_tnt.truncated_newton.TruncatedNewtonProjector(device, dtype, **kwargs)
Truncated Newton projector for the MDOT algorithm.
Projects onto the set of couplings satisfying marginal constraints using a preconditioned conjugate gradient method within a Newton framework.
- project(gamma_C, log_r, log_c, eps_d, u, v)
Project onto the set of couplings that satisfy the marginal constraints.
- Parameters:
gamma_C (Tensor) – The cost matrix scaled by gamma, shape (n, m).
log_r (Tensor) – Log of row marginals, shape (n,).
log_c (Tensor) – Log of column marginals, shape (m,).
eps_d (Union[float, Tensor]) – Convergence tolerance for the dual gradient norm.
u (Tensor) – Initial row dual variables, shape (n,).
v (Tensor) – Initial column dual variables, shape (m,).
- Returns:
u: Updated row dual variables.
v: Updated column dual variables.
logs: Dictionary with optimization statistics.
success: Whether projection converged successfully.
rounding
This file implements the rounding algorithm proposed by Altschuler et al. (2017) in the paper “Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration”. The algorithm rounds an approximate transport plan, such as one obtained from the Sinkhorn algorithm, to a feasible transport plan in the set U(r, c), where r and c are the row and column marginals, respectively. It is used in mdot.py to round the transport plan and to compute the cost of the rounded plan. The implementation follows the original paper.
- mdot_tnt.rounding.round_altschuler(P, r, c)
Performs rounding given a transport plan and marginals.
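The rounding step of Altschuler et al. (2017) can be sketched in a few lines of plain Python on nested lists: scale down rows that exceed r, then columns that exceed c, then spread the missing mass with a rank-one correction. The name `round_altschuler_sketch` is hypothetical, and this dense version is for illustration only; it is not the library's tensor implementation.

```python
def round_altschuler_sketch(P, r, c):
    """Round a nonnegative matrix P to a plan with row sums r and column sums c."""
    n, m = len(r), len(c)
    # Step 1: shrink rows whose mass exceeds the target row marginal.
    row = [sum(P[i]) for i in range(n)]
    x = [min(r[i] / row[i], 1.0) if row[i] > 0 else 1.0 for i in range(n)]
    P1 = [[x[i] * P[i][j] for j in range(m)] for i in range(n)]
    # Step 2: shrink columns whose mass exceeds the target column marginal.
    col = [sum(P1[i][j] for i in range(n)) for j in range(m)]
    y = [min(c[j] / col[j], 1.0) if col[j] > 0 else 1.0 for j in range(m)]
    P2 = [[P1[i][j] * y[j] for j in range(m)] for i in range(n)]
    # Step 3: add the missing mass as a rank-one term err_r err_c^T / ||err_r||_1.
    err_r = [r[i] - sum(P2[i]) for i in range(n)]
    err_c = [c[j] - sum(P2[i][j] for i in range(n)) for j in range(m)]
    l1 = sum(abs(e) for e in err_r)
    if l1 == 0.0:
        return P2  # already feasible
    return [[P2[i][j] + err_r[i] * err_c[j] / l1 for j in range(m)] for i in range(n)]

P = [[0.3, 0.1], [0.2, 0.4]]  # marginals (0.4, 0.6) and (0.5, 0.5): infeasible
r = [0.5, 0.5]
c = [0.6, 0.4]
G = round_altschuler_sketch(P, r, c)
print(G)
```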
- mdot_tnt.rounding.rounded_cost_altschuler(u, v, r, c, C, gamma)
Performs rounding and cost computation in log-domain given dual variables.
This function computes the transport cost without storing the full n×m transport plan, making it memory efficient.
- Parameters:
u (Tensor) – Dual variable for rows of shape (n,).
v (Tensor) – Dual variable for columns of shape (m,).
r (Tensor) – Row marginal of shape (n,).
c (Tensor) – Column marginal of shape (m,).
C (Tensor) – Cost matrix of shape (n, m).
gamma (Union[float, Tensor]) – Temperature (inverse of the entropic regularization weight).
- Return type:
- Returns:
The optimal transport cost as a scalar tensor.
lowmem
Low-memory MDOT-TNT solver using column-batched materialization.
- Achieves O(nk + (n+m)d) working memory by:
Never materializing the full n x m cost matrix C — cost blocks C[:, s:e] are computed on-the-fly from point clouds (X, Y) and a cost function.
Never materializing the full n x m transport plan P — plan blocks are computed on-the-fly from dual variables and cost blocks.
- The user controls the trade-off via block_size (k):
k = m: O(nm) memory, fastest (same as standard algorithm)
k = sqrt(m): O(n*sqrt(m)) memory (good trade-off)
k = 1: O(n) memory, slowest (like element-wise PyKeOps)
- Supports two input modes:
Point cloud mode: X (n, d), Y (m, d), cost_fn — truly O(nk + (n+m)d)
Dense matrix mode: C (n, m) — O(nm) for C storage, O(nk) working memory
- Based on:
“A Truncated Newton Method for Optimal Transport” by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
- mdot_tnt.lowmem.squared_euclidean(X, Y_block)
Squared Euclidean cost: C[i,j] = ||X[i] - Y[j]||^2.
- class mdot_tnt.lowmem.LowMemoryTruncatedNewtonProjector(device, dtype, block_size, cost_block_fn, **kwargs)
Low-memory variant of the Truncated Newton projector.
Instead of storing the full cost matrix or transport plan, cost blocks are computed on-the-fly via a user-supplied cost_block_fn. All LSE reductions, Hessian-vector products, and diagonal preconditioner computations are performed block-by-block with O(nk) working memory.
- Parameters:
device (device) – PyTorch device for computations.
dtype (dtype) – Data type for tensors.
block_size (int) – Number of columns to process at once.
cost_block_fn (Callable[[int, int], Tensor]) – (start, end) -> C[:, start:end] of shape (n, end-start).
**kwargs (Any) – Additional options (debug: bool for verbose output).
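A block-by-block LSE reduction of the kind mentioned above can be sketched with a streaming log-sum-exp: each row keeps a running maximum and a rescaled running sum, so only one n×k cost block is needed at a time. The code below is plain Python on nested lists and is an illustrative assumption, not the class's actual implementation; `blocked_row_lse` and the explicit `n` argument are hypothetical.

```python
import math

def blocked_row_lse(v, gamma, cost_block_fn, n, m, block_size):
    """Row-wise log-sum-exp over j of (v[j] - gamma * C[i][j]), one column block at a time."""
    run_max = [-math.inf] * n  # running maximum per row
    run_sum = [0.0] * n        # running sum of exp(value - run_max) per row
    for s in range(0, m, block_size):
        e = min(s + block_size, m)
        C_blk = cost_block_fn(s, e)  # n rows, (e - s) columns
        for i in range(n):
            vals = [v[s + k] - gamma * C_blk[i][k] for k in range(e - s)]
            new_max = max(run_max[i], max(vals))
            # Rescale the old sum to the new maximum before adding the block.
            run_sum[i] = (run_sum[i] * math.exp(run_max[i] - new_max)
                          + sum(math.exp(x - new_max) for x in vals))
            run_max[i] = new_max
    return [run_max[i] + math.log(run_sum[i]) for i in range(n)]

C = [[0.0, 0.2, 0.4], [0.6, 0.8, 1.0]]
v = [0.1, -0.2, 0.3]
cost_block_fn = lambda s, e: [row[s:e] for row in C]  # dense stand-in for a lazy cost
print(blocked_row_lse(v, 4.0, cost_block_fn, n=2, m=3, block_size=2))
```

The result does not depend on block_size (up to round-off), which is what makes k a pure memory/speed knob.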
- project(gamma, log_r, log_c, eps_d, u, v)
Project onto the set of couplings satisfying marginal constraints.
- Parameters:
gamma (Union[float, Tensor]) – Temperature (inverse regularization weight) for this projection step.
log_r (Tensor) – Log of row marginals, shape (n,).
log_c (Tensor) – Log of column marginals, shape (m,).
eps_d (Union[float, Tensor]) – Convergence tolerance for the dual gradient norm.
u (Tensor) – Initial row dual variables, shape (n,).
v (Tensor) – Initial column dual variables, shape (m,).
- Returns:
u: Updated row dual variables.
v: Updated column dual variables.
logs: Dictionary with optimization statistics.
success: Whether projection converged successfully.
- chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_chi, maxOps=inf)
Chi-squared Sinkhorn iterations using blocked LSE.
- mdot_tnt.lowmem.blocked_rounded_cost(u, v, r, c, cost_block_fn, m, gamma, block_size)
Compute the rounded transport cost without materializing C or P.
Mirrors rounded_cost_altschuler but accesses the cost matrix only through cost_block_fn. Working memory is O(n * block_size).
- Parameters:
u (Tensor) – Dual variable for rows, shape (n,).
v (Tensor) – Dual variable for columns, shape (m,).
r (Tensor) – Row marginal, shape (n,).
c (Tensor) – Column marginal, shape (m,).
cost_block_fn (Callable[[int, int], Tensor]) – (s, e) -> C[:, s:e].
m (int) – Number of columns.
block_size (int) – Number of columns per block.
- Return type:
- Returns:
Transport cost (scalar tensor).
- mdot_tnt.lowmem.blocked_transport_cost(u, v, cost_block_fn, m, gamma, block_size)
Compute unrounded transport cost without materializing C or P.
cost = sum_{i,j} exp(u_i + v_j - gamma * C_{ij}) * C_{ij}
- Parameters:
- Return type:
- Returns:
Transport cost (scalar tensor).
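The formula above can be evaluated one column block at a time, since the double sum splits over column ranges. The following plain-Python sketch accumulates the cost block by block; `blocked_transport_cost_sketch` is a hypothetical stand-in operating on nested lists, not the library function itself.

```python
import math

def blocked_transport_cost_sketch(u, v, cost_block_fn, m, gamma, block_size):
    """sum_{i,j} exp(u_i + v_j - gamma * C_ij) * C_ij, accumulated per column block."""
    total = 0.0
    for s in range(0, m, block_size):
        e = min(s + block_size, m)
        C_blk = cost_block_fn(s, e)  # only this n x (e - s) block is in memory
        for i, row in enumerate(C_blk):
            for k, c_ij in enumerate(row):
                # Plan entry P_ij is rebuilt from the duals and the cost entry.
                total += math.exp(u[i] + v[s + k] - gamma * c_ij) * c_ij
    return total

C = [[0.0, 0.5, 1.0], [1.0, 0.5, 0.0]]
u = [-0.7, -0.9]
v = [-0.4, -0.1, -0.3]
gamma = 3.0
cost_block_fn = lambda s, e: [row[s:e] for row in C]  # dense stand-in for a lazy cost
print(blocked_transport_cost_sketch(u, v, cost_block_fn, m=3, gamma=gamma, block_size=2))
```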
- mdot_tnt.lowmem.mdot_lowmem(r, c, cost_block_fn, gamma_f, block_size, gamma_i=16, p=1.5, q=2.0, devices=None)
Solve entropic-regularized OT using MDOT with low-memory projection.
- Parameters:
r (Tensor) – Row marginal, shape (n,).
c (Tensor) – Column marginal, shape (m,).
cost_block_fn (Callable[[int, int], Tensor]) – (s, e) -> C[:, s:e]; computes cost blocks lazily.
gamma_f (float) – Final temperature (inverse regularization weight).
block_size (int) – Number of columns per block for the projector.
gamma_i (float) – Initial temperature.
p (float) – Exponent for the epsilon schedule.
q (float) – Temperature annealing factor.
devices (Optional[Any]) – Optional list of torch.device objects for multi-GPU execution (from mdot_tnt.multigpu._resolve_devices()). When provided, column blocks are distributed across these devices and blocked primitives run in parallel.
- Returns:
u: Row dual variables, shape (n,).
v: Column dual variables, shape (m,).
gamma: Final temperature achieved.
k_total: Total number of O(n^2) primitive operations.
logs: Optimization statistics.