Source code for mdot_tnt

"""
MDOT-TNT: A Truncated Newton Method for Optimal Transport

This package provides efficient solvers for the entropic-regularized optimal transport
problem, as introduced in the paper "A Truncated Newton Method for Optimal Transport"
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
URL: https://openreview.net/forum?id=gWrWUaCbMa

Main functions:
    solve_OT: Solve a single OT problem.
    solve_OT_batched: Solve multiple OT problems simultaneously (5-10x faster).
    solve_OT_lowmem: Solve a single OT problem with blocked operations on the cost matrix.

Example:
    >>> import torch
    >>> from mdot_tnt import solve_OT, solve_OT_batched
    >>>
    >>> # Single problem
    >>> r = torch.rand(512, device='cuda', dtype=torch.float64)
    >>> r = r / r.sum()
    >>> c = torch.rand(512, device='cuda', dtype=torch.float64)
    >>> c = c / c.sum()
    >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
    >>> cost = solve_OT(r, c, C, gamma_f=1024.)
    >>>
    >>> # Batched (32 problems at once)
    >>> r_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
    >>> r_batch = r_batch / r_batch.sum(-1, keepdim=True)
    >>> c_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
    >>> c_batch = c_batch / c_batch.sum(-1, keepdim=True)
    >>> costs = solve_OT_batched(r_batch, c_batch, C, gamma_f=1024.)
"""

import math
import warnings

import torch as th

from mdot_tnt.batched import solve_OT_batched
from mdot_tnt.lowmem import solve_OT_lowmem
from mdot_tnt.mdot import mdot, preprocess_marginals
from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler

# Public API exported by `from mdot_tnt import *`: solve_OT is defined in this
# module; solve_OT_batched and solve_OT_lowmem are re-exported from the
# submodules imported above.
__all__ = ["solve_OT", "solve_OT_batched", "solve_OT_lowmem"]


def solve_OT(
    r,
    c,
    C,
    gamma_f=1024.0,
    drop_tiny=False,
    return_plan=False,
    round=True,
    log=False,
    devices=None,
    num_gpus=None,
):
    """
    Solve the entropic-regularized optimal transport problem.

    Inputs r, c, C are required to be torch tensors.

    Args:
        r: Row marginal of shape (n,). Must sum to 1.
        c: Column marginal of shape (m,). Must sum to 1.
        C: Cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
        gamma_f: Temperature (inverse of the regularization weight). For many
            problems, stable up to 2^18. Higher values return more accurate
            solutions but take longer to converge. Use double precision if
            gamma_f is large.
        drop_tiny: If either marginal is known to be sparse, set this to True to
            drop tiny entries for a speedup. If return_plan is True, the returned
            plan will be in the original dimensions.
        return_plan: If True, return the optimal transport plan rather than the
            cost.
        round: If True, use the rounding algorithm of Altschuler et al. (2017) to
            (a) return a feasible plan if return_plan is True and (b) the cost of
            the rounded plan if return_plan is False.
        log: If True, additionally return a dictionary containing logs of the
            optimization process.
        devices: List of CUDA devices (``int``, ``str``, or ``torch.device``) to
            distribute column blocks of C across. Enables multi-GPU solving of a
            single OT problem by splitting the cost matrix along the column
            dimension and running blocked operations in parallel. Pass ``[0, 1]``
            or ``[torch.device('cuda:0'), torch.device('cuda:1')]``. Mutually
            exclusive with ``num_gpus``. Ignored when only one device is given.
        num_gpus: Number of CUDA devices to use (starting from index 0). Mutually
            exclusive with ``devices``.

    Returns:
        Transport cost (scalar tensor) if return_plan is False, or the transport
        plan of shape (n, m) if return_plan is True. If log is True, returns a
        tuple of (result, logs_dict). When the inputs are upcast to double
        precision internally (gamma_f > 2^10 and r is not float64), the result
        is cast back to the dtype of r, as announced by the emitted warning.

    Raises:
        AssertionError: If r, c, or C is not a torch tensor.
    """
    assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"

    # Multi-GPU: route through solve_OT_lowmem with column-parallel blocking.
    if devices is not None or num_gpus is not None:
        from mdot_tnt.multigpu import _resolve_devices

        resolved = _resolve_devices(devices, num_gpus)
        if resolved is not None:
            # One column block of C per resolved device.
            block_size = math.ceil(C.shape[-1] / len(resolved))
            return solve_OT_lowmem(
                r,
                c,
                C=C,
                gamma_f=gamma_f,
                block_size=block_size,
                drop_tiny=drop_tiny,
                return_plan=return_plan,
                round=round,
                log=log,
                devices=resolved,
            )

    dtype = r.dtype
    # Require high precision for gamma_f > 2^10
    upcast = bool(gamma_f > 2**10 and dtype != th.float64)
    if upcast:
        warnings.warn(
            "Switching to double precision for gamma_f > 2^10 during execution. "
            f"Output will be input dtype: {dtype}."
        )
        r, c, C = r.double(), c.double(), C.double()

    if drop_tiny:
        drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
        (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
        u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
        # Scatter the reduced potentials back to the original dimensions; dropped
        # entries get -inf so they contribute exp(-inf) = 0 to the plan.
        u = th.full_like(r, float("-inf"))
        u[r_keep] = u_
        v = th.full_like(c, float("-inf"))
        v[c_keep] = v_
    else:
        u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)

    # Keep the potentials in the working precision (that of the possibly upcast
    # marginals). Casting them down to the input dtype before the final exp()
    # would throw away the precision the upcast was meant to provide, and the
    # result would still be promoted to float64 by the double-precision C.
    u, v = u.to(r.dtype), v.to(r.dtype)

    if return_plan:
        result = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp()
        if round:
            result = round_altschuler(result, r, c)
    elif round:
        result = rounded_cost_altschuler(u, v, r, c, C, gamma_f_)
    else:
        result = ((u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp() * C).sum()

    # Honor the promise made in the warning: hand back the caller's dtype, but
    # only when we changed precision ourselves.
    if upcast and isinstance(result, th.Tensor):
        result = result.to(dtype)

    if log:
        return result, opt_logs
    return result