Source code for mdot_tnt.lowmem

"""
Low-memory MDOT-TNT solver using column-batched materialization.

Achieves O(nk + (n+m)d) working memory by:
  1. Never materializing the full n x m cost matrix C — cost blocks C[:, s:e]
     are computed on-the-fly from point clouds (X, Y) and a cost function.
  2. Never materializing the full n x m transport plan P — plan blocks are
     computed on-the-fly from dual variables and cost blocks.

The user controls the trade-off via block_size (k):
  - k = m: O(nm) memory, fastest (same as standard algorithm)
  - k = sqrt(m): O(n*sqrt(m)) memory (good trade-off)
  - k = 1: O(n) memory, slowest (like element-wise PyKeOps)

Supports two input modes:
  - Point cloud mode: X (n, d), Y (m, d), cost_fn — truly O(nk + (n+m)d)
  - Dense matrix mode: C (n, m) — O(nm) for C storage, O(nk) working memory

Based on:
    "A Truncated Newton Method for Optimal Transport"
    by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
"""

import math
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch as th

from mdot_tnt.mdot import adjust_schedule, smooth_marginals
from mdot_tnt.rounding import round_altschuler

# Type alias for a lazy cost-matrix accessor. Called as ``fn(start, end)`` it
# returns the column block C[:, start:end], a tensor of shape (n, end - start).
CostBlockFn = Callable[[int, int], th.Tensor]


# ---------------------------------------------------------------------------
# Cost functions for point clouds
# ---------------------------------------------------------------------------


def squared_euclidean(X: th.Tensor, Y_block: th.Tensor) -> th.Tensor:
    """
    Squared Euclidean cost: C[i,j] = ||X[i] - Y[j]||^2.

    Computed with the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>,
    which needs only one matrix product. Floating-point cancellation in that
    expansion can produce tiny negative values for (near-)coincident points,
    so the result is clamped at zero; a negative cost would otherwise become
    NaN wherever the log of a cost block is taken.

    Args:
        X: Source points, shape (n, d).
        Y_block: Target point block, shape (k, d).

    Returns:
        Non-negative cost block of shape (n, k).
    """
    XX = (X * X).sum(-1, keepdim=True)  # (n, 1)
    YY = (Y_block * Y_block).sum(-1).unsqueeze(0)  # (1, k)
    return (XX + YY - 2.0 * (X @ Y_block.T)).clamp(min=0.0)
def euclidean(X: th.Tensor, Y_block: th.Tensor) -> th.Tensor:
    """
    Euclidean cost: C[i,j] = ||X[i] - Y[j]||.

    Args:
        X: Source points, shape (n, d).
        Y_block: Target point block, shape (k, d).

    Returns:
        Cost block of shape (n, k).
    """
    sq_dist = squared_euclidean(X, Y_block)
    # Clamp before the square root so round-off negatives cannot yield NaN.
    return th.sqrt(th.clamp(sq_dist, min=0.0))
# --------------------------------------------------------------------------- # Projector # ---------------------------------------------------------------------------
class LowMemoryTruncatedNewtonProjector:
    """
    Low-memory variant of the Truncated Newton projector.

    Instead of storing the full cost matrix or transport plan, cost blocks are
    computed on-the-fly via a user-supplied ``cost_block_fn``. All LSE
    reductions, Hessian-vector products, and diagonal preconditioner
    computations are performed block-by-block with O(nk) working memory.

    Args:
        device: PyTorch device for computations.
        dtype: Data type for tensors.
        block_size: Number of columns to process at once (must be >= 1).
        cost_block_fn: ``(start, end) -> C[:, start:end]`` of shape
            ``(n, end-start)``.
        **kwargs: Additional options (debug: bool for verbose output).
    """

    def __init__(
        self,
        device: th.device,
        dtype: th.dtype,
        block_size: int,
        cost_block_fn: "CostBlockFn",
        **kwargs: Any,
    ) -> None:
        self.device = device
        self.dtype = dtype
        self.block_size = block_size
        self.cost_block_fn = cost_block_fn
        # Temperature of the current projection; overwritten by project().
        self.gamma: Union[float, th.Tensor] = 1.0
        # Discount factor applied to the P diag(1/c) P^T term in newton_solve;
        # carried over between Newton steps.
        self.rho = th.zeros(1, device=device, dtype=dtype)
        self.debug = kwargs.get("debug", False)

    # -- helpers -------------------------------------------------------------

    def _col_blocks(self, m: int):
        """
        Yield (start, end) index pairs for column blocks.

        Raises:
            ValueError: If ``block_size`` is not positive. A negative step
                would otherwise yield no blocks at all and only fail later
                with an unrelated error.
        """
        if self.block_size < 1:
            raise ValueError(f"block_size must be a positive integer, got {self.block_size}.")
        for start in range(0, m, self.block_size):
            yield start, min(start + self.block_size, m)

    def _gamma_C_block(self, s: int, e: int) -> th.Tensor:
        """Return gamma * C[:, s:e], shape (n, e-s)."""
        return self.gamma * self.cost_block_fn(s, e)

    def _P_block(self, u: th.Tensor, v: th.Tensor, s: int, e: int) -> th.Tensor:
        """Compute P[:, s:e] on-the-fly from dual variables and cost block."""
        return th.exp(u.unsqueeze(-1) + v[s:e].unsqueeze(-2) - self._gamma_C_block(s, e))

    # -- blocked primitives --------------------------------------------------

    def _blocked_lse_r(self, v: th.Tensor) -> th.Tensor:
        """
        logsumexp(v[None, :] - gamma*C, dim=-1) via streaming logaddexp.

        Each block contributes a partial row-wise logsumexp; partials are
        merged with logaddexp so only one block is alive at a time.

        Returns shape (n,).
        """
        m = v.shape[-1]
        result: Optional[th.Tensor] = None
        for s, e in self._col_blocks(m):
            partial = th.logsumexp(v[s:e].unsqueeze(-2) - self._gamma_C_block(s, e), dim=-1)
            if result is None:
                result = partial
            else:
                result = th.logaddexp(result, partial)
        assert result is not None
        return result

    def _blocked_lse_c(self, u: th.Tensor, m: int) -> th.Tensor:
        """
        logsumexp(u[:, None] - gamma*C, dim=-2) via column blocks.

        Each block of k columns produces k output elements (reduction is over
        rows), so results are concatenated directly.

        Returns shape (m,).
        """
        results = []
        for s, e in self._col_blocks(m):
            results.append(th.logsumexp(u.unsqueeze(-1) - self._gamma_C_block(s, e), dim=-2))
        return th.cat(results, dim=-1)

    def _blocked_diag_PPc(
        self,
        u: th.Tensor,
        v: th.Tensor,
        c: th.Tensor,
    ) -> th.Tensor:
        """
        Compute diag(P diag(1/c) P^T) = ((P**2) / c[None, :]).sum(-1)
        in column blocks.

        Returns shape (n,).
        """
        m = v.shape[-1]
        result = th.zeros_like(u)
        for s, e in self._col_blocks(m):
            Pb = self._P_block(u, v, s, e)
            result += ((Pb**2) / c[s:e].unsqueeze(-2)).sum(-1)
            del Pb
        return result

    def _blocked_PPc_matmul(
        self,
        u: th.Tensor,
        v: th.Tensor,
        c: th.Tensor,
        x: th.Tensor,
    ) -> th.Tensor:
        """
        P diag(1/c) P^T x in a single pass over column blocks.

        Uses the decomposition:
            P diag(1/c) P^T x = sum_b P_b diag(1/c_b) P_b^T x

        Returns shape (n,).
        """
        m = v.shape[-1]
        result = th.zeros_like(x)
        for s, e in self._col_blocks(m):
            Pb = self._P_block(u, v, s, e)
            z = (x @ Pb) / c[s:e]  # P_b^T x / c_b, shape (k,)
            result += Pb @ z  # P_b z, shape (n,)
            del Pb, z
        return result

    def _blocked_Pc_x(
        self,
        u: th.Tensor,
        v: th.Tensor,
        c: th.Tensor,
        x: th.Tensor,
    ) -> th.Tensor:
        """
        diag(1/c) P^T x in column blocks.

        Per-block results are concatenated (as in ``_blocked_lse_c``) so the
        output keeps the dtype and device of the operands instead of being
        forced to the projector's construction-time dtype/device.

        Returns shape (m,).
        """
        m = v.shape[-1]
        parts = []
        for s, e in self._col_blocks(m):
            Pb = self._P_block(u, v, s, e)
            parts.append((x @ Pb) / c[s:e])
            del Pb
        return th.cat(parts, dim=-1)

    # -- main projection -----------------------------------------------------

    def project(
        self,
        gamma: Union[float, th.Tensor],
        log_r: th.Tensor,
        log_c: th.Tensor,
        eps_d: Union[float, th.Tensor],
        u: th.Tensor,
        v: th.Tensor,
    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], bool]:
        """
        Project onto the set of couplings satisfying marginal constraints.

        Note that ``u`` and ``v`` are updated in place.

        Args:
            gamma: Temperature (inverse regularization weight) for this
                projection step.
            log_r: Log of row marginals, shape (n,).
            log_c: Log of column marginals, shape (m,).
            eps_d: Convergence tolerance for the dual gradient norm.
            u: Initial row dual variables, shape (n,).
            v: Initial column dual variables, shape (m,).

        Returns:
            u: Updated row dual variables.
            v: Updated column dual variables.
            logs: Dictionary with optimization statistics.
            success: Whether projection converged successfully.

        Raises:
            ValueError: If NaNs appear in ``u`` or ``v`` after a Newton step.
        """
        self.gamma = gamma
        m = v.shape[-1]

        logs: Dict[str, Any] = {
            "errs": [],
            "ls_func_cnt": 0,
            "chisinkhorn_steps": 0,
            "newtonsolve_steps": 0,
            "deltas": [],
            "all_newtonsolve_steps": [],
        }
        # Early exits still count as a success when the error is within 10x
        # of the requested tolerance.
        success_fn = lambda err_: err_ < 10 * eps_d

        r = log_r.exp()
        c = log_c.exp()

        # Blocked LSE operations — O(nk) working memory per call.
        self.LSE_r = lambda v_: self._blocked_lse_r(v_)
        self.LSE_c = lambda u_: self._blocked_lse_c(u_, m)

        # Initial column normalization so the column marginal is matched.
        log_c_P = v + self.LSE_c(u)
        v += log_c - log_c_P
        log_r_P = u + self.LSE_r(v)

        # k counts primitive operations for the returned statistics.
        k = 8
        u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
        r_P = log_r_P.exp()
        logs["errs"].append(err)
        logs["chisinkhorn_steps"] = k_
        k += k_
        num_iter = 0
        while err > eps_d:
            num_iter += 1
            beta = 0.5
            eta_k = th.max(err, 0.9 * (eps_d / err))
            grad_k = r_P - r
            self.rho = th.max(th.zeros_like(self.rho), self.rho)
            diag_PPc = self._blocked_diag_PPc(u, v, c)
            k += 8
            delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
                u,
                v,
                c,
                diag_PPc,
                grad_k,
                r_P,
                err,
                beta,
                eta_k,
                maxIter=5000,
            )
            if not pcg_success:
                k += matmul_cnt
                logs["n_iter"] = k
                msg = (
                    "PCG did not converge. TruncatedNewton returning with "
                    f"success={success_fn(err)}"
                )
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            self.rho = th.max(th.zeros_like(self.rho), 1.0 - (1.0 - rho) * 4.0)
            k += matmul_cnt
            logs["newtonsolve_steps"] += matmul_cnt

            # Backtracking line search on the step size alpha.
            alpha = th.ones_like(self.rho)
            log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
            k += 4
            linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
            if not linear_decr > 0:
                logs["n_iter"] = k
                msg = (
                    "Linear decrease condition not satisfied. TruncatedNewton "
                    f"returning with success={success_fn(err)}"
                )
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr
            while not armijo:
                alpha *= 0.5
                if alpha < 1e-9:
                    logs["n_iter"] = k
                    msg = (
                        "Line search did not converge. TruncatedNewton "
                        f"returning with success={success_fn(err)}"
                    )
                    warnings.warn(msg)
                    return u, v, logs, success_fn(err)

                log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
                k += 4
                logs["ls_func_cnt"] += 4
                armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr

            u += alpha * delta_u
            v += alpha * delta_v

            # Total marginal error right after the Newton step, before the
            # Sinkhorn clean-up; used for the "deltas" statistic below.
            err_before_sk = (c - log_c_P.exp()).abs().sum(-1)
            err_before_sk += (r - (u + self.LSE_r(v)).exp()).abs().sum(-1)

            v += log_c - log_c_P
            log_r_P = u + self.LSE_r(v)
            k += 4

            u, v, log_r_P, err, k_ = self.chi_sinkhorn(
                u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
            )
            r_P = log_r_P.exp()
            logs["chisinkhorn_steps"] += k_
            k += k_

            logs["errs"].append(err)
            logs["deltas"].append(
                th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item()
            )

            if u.isnan().any() or v.isnan().any():
                raise ValueError("NaNs encountered in u or v")

        logs["n_iter"] = k

        # Final row normalization.
        delta_u = log_r - log_r_P
        u += delta_u

        return u, v, logs, True

    # -- chi-squared Sinkhorn ------------------------------------------------

    def chi_sinkhorn(
        self,
        u: th.Tensor,
        v: th.Tensor,
        log_r: th.Tensor,
        log_c: th.Tensor,
        log_r_P: th.Tensor,
        eps_chi: Union[float, th.Tensor],
        maxOps: float = float("inf"),
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, int]:
        """
        Chi-squared Sinkhorn iterations using blocked LSE.

        Alternates row and column normalizations until the chi-squared
        divergence between the target row marginal and the current one drops
        to ``eps_chi``. Relies on ``self.LSE_r`` / ``self.LSE_c``, which are
        set by ``project``. ``u`` and ``v`` are updated in place.

        Args:
            u: Row dual variables, shape (n,).
            v: Column dual variables, shape (m,).
            log_r: Log of row marginals, shape (n,).
            log_c: Log of column marginals, shape (m,).
            log_r_P: Log of the current row sums of the plan, shape (n,).
            eps_chi: Chi-squared stopping threshold.
            maxOps: Upper bound on the operation counter ``k``.

        Returns:
            u, v: Updated dual variables.
            log_r_P: Log row sums of the plan at the returned duals.
            err: L1 error of the row marginal.
            k: Operation counter (increases by 8 per iteration).

        Raises:
            ValueError: If the operation counter reaches ``maxOps``.
        """
        k = 0
        r = log_r.exp()
        err = (r - log_r_P.exp()).norm(p=1, dim=-1)
        r_P = log_r_P.exp()
        chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
        while chi_squared > eps_chi and k < maxOps:
            u += log_r - log_r_P
            log_c_P = v + self.LSE_c(u)
            v += log_c - log_c_P
            log_r_P = u + self.LSE_r(v)
            r_P = log_r_P.exp()
            err = (r - r_P).norm(p=1, dim=-1)
            chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
            k += 8

        if k >= maxOps:
            raise ValueError(f"Chi-Sinkhorn did not converge in maxIter={maxOps} steps")

        return u, v, log_r_P, err, k

    # -- Newton solve --------------------------------------------------------

    def newton_solve(
        self,
        u: th.Tensor,
        v: th.Tensor,
        c: th.Tensor,
        diag_PPc: th.Tensor,
        grad_k: th.Tensor,
        r_P: th.Tensor,
        err: th.Tensor,
        beta: float = 0.5,
        eta_k: Union[float, th.Tensor] = 0.5,
        maxIter: int = 500,
    ) -> Tuple[th.Tensor, th.Tensor, int, th.Tensor, bool]:
        """
        Newton solve using blocked Hessian-vector products.

        Transport plan blocks are recomputed on-the-fly for each mat-vec
        product from the (unchanged) dual variables u, v and the cost function.

        Args:
            u: Row dual variables, shape (n,).
            v: Column dual variables, shape (m,).
            c: Column marginals, shape (m,).
            diag_PPc: diag(P diag(1/c) P^T), shape (n,), used in the diagonal
                preconditioner.
            grad_k: Gradient ``r_P - r``, shape (n,).
            r_P: Current row sums of the plan, shape (n,).
            err: Current marginal error; the residual tolerance is
                ``err * eta_k``.
            beta: Factor in the inner CG stopping threshold.
            eta_k: Forcing term scaling the residual tolerance.
            maxIter: The solve gives up once the mat-vec counter exceeds
                ``2 * maxIter``.

        Returns:
            x: Step for u, shape (n,).
            -Pc_x: Step for v, shape (m,).
            matmul_cnt: Mat-vec operation counter.
            rho: Discount factor at exit.
            success: Whether the solve reached the tolerance.

        Raises:
            ValueError: If the initial guess is not a descent direction, or
                if the CG residual norm becomes NaN/inf.
        """
        # NOTE: rho aliases self.rho, so the in-place updates below also
        # modify the projector's stored discount factor.
        rho = self.rho
        tol = err * eta_k

        def mml(x_):
            return self._blocked_PPc_matmul(u, v, c, x_)

        M = lambda rho_: r_P - rho_ * diag_PPc  # diagonal preconditioner
        M_rho = M(th.ones_like(self.rho))
        # Keep the preconditioner strictly positive.
        M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()

        x0 = -grad_k / M_rho
        PPc_x0 = mml(x0)
        matmul_cnt = 2
        r_P_x0 = r_P * x0

        x = x0.clone()
        PPc_x = PPc_x0.clone()
        r_P_x = r_P_x0.clone()

        res_true = r_P_x0 - PPc_x + grad_k
        linear_decr = (x * -grad_k).sum(-1)
        if linear_decr <= 0:
            raise ValueError("Linear decrease condition not satisfied")

        r_true_norm = res_true.norm(p=1, dim=-1)
        best_sol = x.clone()
        best_r_true_norm = r_true_norm.clone()

        done = False
        success = True
        while best_r_true_norm > tol:
            best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
            # Move rho towards 1 while the true residual is above tolerance.
            rho[r_true_norm > tol] = 1.0 - (1.0 - rho[r_true_norm > tol]) * 0.25
            M_rho = M(rho)

            if matmul_cnt > 0:
                # Restart CG from the initial guess for the new rho.
                x = x0.clone()
                PPc_x = PPc_x0.clone()
                r_P_x = r_P_x0.clone()

            Fr_x = r_P_x - rho * PPc_x
            res = Fr_x + grad_k
            res_true = r_P_x - PPc_x + grad_k
            r_true_norm = res_true.norm(p=1, dim=-1)
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
            linear_decr = (x * -grad_k).sum(-1)
            if (best_r_true_norm < tol).all() and (linear_decr > 0).all():
                break

            y = res / M_rho
            p = -y.clone()
            ry_old = (res * y).sum(-1, keepdim=True)
            r_norm = res.norm(p=1, dim=-1)

            while (r_norm > 0.5 * (1 - beta) * tol)[best_r_true_norm > tol].any():
                PPc_p = mml(p)
                matmul_cnt += 2
                Fr_p = (r_P * p) - rho * PPc_p
                quad = (Fr_p * p).sum(-1, keepdim=True)
                if (quad <= 0)[best_r_true_norm > tol].any():
                    warnings.warn(
                        "Warning: negative curvature encountered in CG. "
                        "Returning best solution. Residual norm less than "
                        f"error: {(best_r_true_norm < err).item()}"
                    )
                    x = best_sol.clone()
                    done = True
                    success = best_r_true_norm < err
                    warnings.warn("Resetting discount factor rho = 0")
                    rho = th.zeros_like(self.rho)
                    break

                alpha = ry_old / quad
                x += alpha * p
                res += alpha * Fr_p
                r_norm = res.norm(p=1, dim=-1)
                if (
                    th.isnan(r_norm)[best_r_true_norm > tol].any()
                    or th.isinf(r_norm)[best_r_true_norm > tol].any()
                ):
                    raise ValueError("NaNs or infs encountered in r_norm")

                PPc_x += alpha * PPc_p
                r_P_x = r_P * x
                res_true = r_P_x - PPc_x + grad_k
                r_true_norm = res_true.norm(p=1, dim=-1)
                best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
                best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
                linear_decr = (x * -grad_k).sum(-1)
                if (best_r_true_norm <= tol).all() and (linear_decr > 0).all():
                    done = True
                    success = True
                    break

                if matmul_cnt > 2 * maxIter:
                    warnings.warn("PCG did not converge.")
                    done = True
                    success = False
                    break

                y = res / M_rho
                ry_new = (res * y).sum(-1, keepdim=True)
                p = -y + (ry_new / ry_old) * p
                ry_old = ry_new.clone()

            if done:
                break

        if r_true_norm <= tol:
            success = True

        x = best_sol
        # Column-dual step: -diag(1/c) P^T x.
        Pc_x = self._blocked_Pc_x(u, v, c, x)
        matmul_cnt += 1

        return x, -Pc_x, matmul_cnt, rho, success
# --------------------------------------------------------------------------- # Blocked cost computation # ---------------------------------------------------------------------------
def blocked_rounded_cost(
    u: th.Tensor,
    v: th.Tensor,
    r: th.Tensor,
    c: th.Tensor,
    cost_block_fn: "CostBlockFn",
    m: int,
    gamma: Union[float, th.Tensor],
    block_size: int,
) -> th.Tensor:
    """
    Compute the rounded transport cost without materializing C or P.

    Mirrors ``rounded_cost_altschuler`` but accesses the cost matrix only
    through ``cost_block_fn``. Working memory is O(n * block_size).

    The cost sum_{i,j} P_{ij} C_{ij} is accumulated directly rather than in
    the log domain, so cost entries may be zero or negative without turning
    the result into NaN.

    Args:
        u: Dual variable for rows, shape (n,).
        v: Dual variable for columns, shape (m,).
        r: Row marginal, shape (n,).
        c: Column marginal, shape (m,).
        cost_block_fn: ``(s, e) -> C[:, s:e]``.
        m: Number of columns.
        gamma: Temperature.
        block_size: Number of columns per block (must be >= 1).

    Returns:
        Transport cost (scalar tensor).

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")

    def col_blocks():
        for s in range(0, m, block_size):
            yield s, min(s + block_size, m)

    def blocked_lse_r(v_: th.Tensor):
        # Row-wise logsumexp over all columns, merged block by block.
        result = None
        for s, e in col_blocks():
            partial = th.logsumexp(v_[s:e].unsqueeze(-2) - gamma * cost_block_fn(s, e), dim=-1)
            result = partial if result is None else th.logaddexp(result, partial)
        return result

    def blocked_lse_c(u_: th.Tensor) -> th.Tensor:
        # Column-wise logsumexp; each block yields its own output slice.
        parts = [
            th.logsumexp(u_.unsqueeze(-1) - gamma * cost_block_fn(s, e), dim=-2)
            for s, e in col_blocks()
        ]
        return th.cat(parts, dim=-1)

    # Step 1: Row rounding (only ever scales rows down).
    r_P_log = u + blocked_lse_r(v)
    u = u + th.min(r.log() - r_P_log, th.zeros_like(r))

    # Step 2: Column rounding (only ever scales columns down).
    c_P_log = v + blocked_lse_c(u)
    v = v + th.min(c.log() - c_P_log, th.zeros_like(c))

    # Step 3: Residual errors for rank-1 correction.
    err_r = r - (u + blocked_lse_r(v)).exp()
    err_r = err_r / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30)
    err_c = c - (v + blocked_lse_c(u)).exp()

    # Steps 4 and 5 in a single pass so each cost block is built only once:
    #   sum_{i,j} exp(u_i + v_j - gamma*C_{ij}) * C_{ij}  +  err_r^T C err_c
    # Out-of-place accumulation lets the result follow normal dtype promotion.
    cost = th.zeros((), device=u.device, dtype=u.dtype)
    for s, e in col_blocks():
        Cb = cost_block_fn(s, e)
        Pb = th.exp(u.unsqueeze(-1) + v[s:e].unsqueeze(-2) - gamma * Cb)
        cost = cost + (Pb * Cb).sum() + (err_r @ Cb) @ err_c[s:e]
        del Cb, Pb

    return cost
def blocked_transport_cost(
    u: th.Tensor,
    v: th.Tensor,
    cost_block_fn: "CostBlockFn",
    m: int,
    gamma: Union[float, th.Tensor],
    block_size: int,
) -> th.Tensor:
    """
    Compute unrounded transport cost without materializing C or P.

    cost = sum_{i,j} exp(u_i + v_j - gamma * C_{ij}) * C_{ij}

    The sum is accumulated directly, one column block at a time, so cost
    entries may be zero or negative without turning the result into NaN
    (a log-domain accumulation would require strictly non-negative costs).

    Args:
        u: Dual variable for rows, shape (n,).
        v: Dual variable for columns, shape (m,).
        cost_block_fn: ``(s, e) -> C[:, s:e]``.
        m: Number of columns.
        gamma: Temperature.
        block_size: Number of columns per block (must be >= 1).

    Returns:
        Transport cost (scalar tensor).

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")

    # Out-of-place accumulation lets the result follow normal dtype promotion.
    cost = th.zeros((), device=u.device, dtype=u.dtype)
    for s in range(0, m, block_size):
        e = min(s + block_size, m)
        Cb = cost_block_fn(s, e)
        Pb = th.exp(u.unsqueeze(-1) + v[s:e].unsqueeze(-2) - gamma * Cb)
        cost = cost + (Pb * Cb).sum()
        del Cb, Pb
    return cost
# --------------------------------------------------------------------------- # MDOT solver (low-memory variant) # ---------------------------------------------------------------------------
def mdot_lowmem(
    r: th.Tensor,
    c: th.Tensor,
    cost_block_fn: CostBlockFn,
    gamma_f: float,
    block_size: int,
    gamma_i: float = 16,
    p: float = 1.5,
    q: float = 2.0,
    devices: Optional[Any] = None,
) -> Tuple[th.Tensor, th.Tensor, float, int, Dict[str, Any]]:
    """
    Solve entropic-regularized OT using MDOT with low-memory projection.

    The temperature is annealed from ``min(gamma_i, gamma_f)`` up to
    ``gamma_f``; at each temperature the duals are projected and then
    linearly extrapolated to warm-start the next temperature.

    Args:
        r: Row marginal, shape (n,).
        c: Column marginal, shape (m,).
        cost_block_fn: ``(s, e) -> C[:, s:e]`` — computes cost blocks lazily.
        gamma_f: Final temperature (inverse regularization weight).
        block_size: Number of columns per block for the projector. Not passed
            to the projector when ``devices`` is given.
        gamma_i: Initial temperature.
        p: Exponent for the epsilon schedule.
        q: Temperature annealing factor.
        devices: Optional list of ``torch.device`` objects for multi-GPU
            execution (from :func:`mdot_tnt.multigpu._resolve_devices`).
            When provided, column blocks are distributed across these
            devices and blocked primitives run in parallel.

    Returns:
        u: Row dual variables, shape (n,).
        v: Column dual variables, shape (m,).
        gamma: Final temperature achieved.
        k_total: Total number of O(n^2) primitive operations.
        logs: Optimization statistics.
    """
    if devices is not None:
        from mdot_tnt.multigpu import MultiGPULowMemoryProjector

        m = c.shape[0]
        projector: Any = MultiGPULowMemoryProjector(
            device=r.device,
            dtype=r.dtype,
            cost_block_fn=cost_block_fn,
            m=m,
            c=c,
            devices=devices,
        )
    else:
        projector = LowMemoryTruncatedNewtonProjector(
            device=r.device,
            dtype=r.dtype,
            block_size=block_size,
            cost_block_fn=cost_block_fn,
        )

    # The tolerance schedule is scaled by the smaller marginal entropy.
    H_r = -(r * (r + 1e-30).log()).sum(-1)
    H_c = -(c * (c + 1e-30).log()).sum(-1)
    H_min = th.min(H_r, H_c)
    eps_fn = lambda g_: H_min / (g_**p)

    logs: Dict[str, Any] = {
        "proj_logs": [],
        "eps": [],
    }

    t = 1
    done = False
    gamma = min(gamma_i, gamma_f)
    # The leading 0.0 is a placeholder "previous temperature" so the
    # extrapolation below is well-defined on the first iteration.
    gammas = [0.0, gamma]

    while not done:
        done = abs(gamma - gamma_f) < 1e-5

        eps_d = eps_fn(gamma)

        r_hat, c_hat = smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)

        if t == 1:
            u_init, v_init = r_hat.log(), c_hat.log()
            u_cur, v_cur = u_init.clone(), v_init.clone()

        u_prev, v_prev = u_cur.clone(), v_cur.clone()
        u_cur, v_cur, proj_log, success = projector.project(
            gamma, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init
        )

        logs["proj_logs"].append(proj_log)
        if not success:
            # When the very first projection fails, gammas[-2] is still the
            # 0.0 placeholder; report an infinite temperature rather than
            # crashing with ZeroDivisionError while building the message.
            last_gamma = gammas[-2]
            last_temperature = 1 / last_gamma if last_gamma > 0 else float("inf")
            warnings.warn(
                f"Projection failed. Returning result at the last temperature: {last_temperature:.4e}"
            )
            u_cur = u_prev.clone()
            v_cur = v_prev.clone()
            gammas = gammas[:-1]
            break

        q = adjust_schedule(q, proj_log["deltas"])
        gamma = min(gamma * q, gamma_f)

        if not done:
            # Linear extrapolation of the duals in gamma as a warm start.
            u_init = u_cur + (u_cur - u_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            v_init = v_cur + (v_cur - v_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            gammas.append(gamma)
            t += 1

    k_total = sum([log["n_iter"] for log in logs["proj_logs"]])
    k_total += t - 1
    logs["success"] = success
    logs["gammas"] = gammas

    return u_cur, v_cur, gammas[-1], k_total, logs
# --------------------------------------------------------------------------- # Marginal preprocessing (inlined to avoid depending on a dense C) # --------------------------------------------------------------------------- def _preprocess_marginal(m: th.Tensor, eps: float) -> Tuple[th.Tensor, th.Tensor]: """Drop smallest entries whose cumulative mass is below ``eps``.""" m_sorted, m_idx = th.sort(m, dim=-1, descending=False) m_cumsum = th.cumsum(m_sorted, dim=-1) m_keep = m_idx[m_cumsum > eps] m_new = m[m_keep] mass_removed = 1 - m_new.sum(-1) m_new = m_new + mass_removed / m_new.size(-1) return m_new, m_keep # --------------------------------------------------------------------------- # Top-level API # ---------------------------------------------------------------------------
def solve_OT_lowmem(
    r: th.Tensor,
    c: th.Tensor,
    C: Optional[th.Tensor] = None,
    X: Optional[th.Tensor] = None,
    Y: Optional[th.Tensor] = None,
    cost_fn: Optional[Callable[[th.Tensor, th.Tensor], th.Tensor]] = None,
    gamma_f: float = 1024.0,
    block_size: Optional[int] = None,
    drop_tiny: bool = False,
    return_plan: bool = False,
    round: bool = True,
    log: bool = False,
    devices: Optional[Any] = None,
    num_gpus: Optional[int] = None,
):
    """
    Solve entropic-regularized OT with O(nk) working memory.

    Accepts either a dense cost matrix **or** point clouds with a cost
    function. In point-cloud mode the cost matrix is never materialised,
    giving true O(nk + (n+m)d) memory.

    Args:
        r: Row marginal, shape (n,). Must sum to 1.
        c: Column marginal, shape (m,). Must sum to 1.
        C: Dense cost matrix, shape (n, m). Mutually exclusive with X/Y.
            Recommended to scale entries to [0, 1].
        X: Source points, shape (n, d). Use with Y and optionally cost_fn.
        Y: Target points, shape (m, d). Use with X.
        cost_fn: ``(X, Y_block) -> C_block`` of shape (n, k). Defaults to
            ``squared_euclidean``. Only used when X/Y are provided.
        gamma_f: Temperature (inverse regularization weight).
        block_size: Columns per block (must be >= 1 when given). Controls the
            memory / compute trade-off:
            - None or m: fastest, O(nm) or O(nk) memory
            - sqrt(m): good trade-off
            - 1: minimum memory (slowest)
        drop_tiny: Drop tiny marginal entries for speedup with sparse
            marginals.
        return_plan: If True, return the full transport plan. Note: the plan
            is inherently O(nm); in point-cloud mode it is built
            block-by-block so that C and P are never both in memory
            simultaneously.
        round: Use Altschuler rounding for feasible plans / costs.
        log: Additionally return optimization logs.
        devices: List of CUDA devices (``int``, ``str``, or ``torch.device``)
            to distribute column blocks across. Pass ``[0, 1]`` or
            ``[torch.device('cuda:0'), torch.device('cuda:1')]``. Mutually
            exclusive with ``num_gpus``. When a single device is given the
            standard single-GPU solver is used.
        num_gpus: Number of CUDA devices to use (starting from index 0).
            Mutually exclusive with ``devices``.

    Returns:
        Transport cost (scalar) if return_plan is False, or the transport
        plan (n, m) if return_plan is True. The result has the dtype of the
        input ``r`` even when the solve was run in double precision.
        If log is True, returns (result, logs).

    Raises:
        ValueError: If not exactly one of ``C`` or ``X``/``Y`` is provided,
            or if ``block_size`` is given and is not positive.
    """
    # -- multi-GPU device resolution -----------------------------------------
    from mdot_tnt.multigpu import _resolve_devices

    resolved_devices = _resolve_devices(devices, num_gpus)

    # -- input validation ----------------------------------------------------
    have_C = C is not None
    have_XY = X is not None and Y is not None
    if not (have_C ^ have_XY):
        raise ValueError(
            "Provide exactly one of: C (dense cost matrix) or X and Y "
            f"(point clouds). Got C={C is not None}, X={X is not None}, Y={Y is not None}."
        )
    if block_size is not None and block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")

    # -- dtype promotion -----------------------------------------------------
    dtype = r.dtype
    if gamma_f > 2**10 and dtype != th.float64:
        warnings.warn(
            "Switching to double precision for gamma_f > 2^10 during "
            f"execution. Output will be input dtype: {dtype}."
        )
        r, c = r.double(), c.double()
        if have_C:
            assert C is not None
            C = C.double()
        else:
            assert X is not None and Y is not None
            X, Y = X.double(), Y.double()

    # -- build cost_block_fn -------------------------------------------------
    cost_block_fn: CostBlockFn
    if have_C:
        assert C is not None
        m = C.shape[-1]
        cost_block_fn = lambda s, e: C[:, s:e]
    else:
        assert X is not None and Y is not None
        m = Y.shape[0]
        _cf = cost_fn if cost_fn is not None else squared_euclidean
        cost_block_fn = lambda s, e: _cf(X, Y[s:e])

    if block_size is None:
        if resolved_devices is not None:
            # In multi-GPU mode the projector splits m across K GPUs. Use the
            # same per-GPU column count for the final cost computation so that
            # blocked_rounded_cost / blocked_transport_cost never try to
            # materialise the full cost matrix on the primary device.
            block_size = math.ceil(m / len(resolved_devices))
        else:
            block_size = m

    # -- optional marginal sparsification ------------------------------------
    if drop_tiny:
        drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
        r_, r_keep = _preprocess_marginal(r, drop_lessthan)
        c_, c_keep = _preprocess_marginal(c, drop_lessthan)

        cbf_: CostBlockFn
        if have_C:
            assert C is not None
            C_ = C[r_keep][:, c_keep]
            cbf_ = lambda s, e: C_[:, s:e]
        else:
            assert X is not None and Y is not None
            X_ = X[r_keep]
            Y_ = Y[c_keep]
            cbf_ = lambda s, e: _cf(X_, Y_[s:e])

        u_, v_, gamma_f_, k_total, opt_logs = mdot_lowmem(
            r_, c_, cbf_, gamma_f, block_size, devices=resolved_devices
        )

        # Dropped entries get a dual of -inf, i.e. zero mass in the plan.
        u = -th.ones_like(r) * float("inf")
        u[r_keep] = u_
        v = -th.ones_like(c) * float("inf")
        v[c_keep] = v_
    else:
        u, v, gamma_f_, k_total, opt_logs = mdot_lowmem(
            r, c, cost_block_fn, gamma_f, block_size, devices=resolved_devices
        )

    u, v = u.to(dtype), v.to(dtype)

    # -- final cost / plan ---------------------------------------------------
    # cost_block_fn was defined before drop_tiny, so it operates on the full
    # (un-subsetted) dimensions, matching the full-size u and v above.
    if return_plan:
        # Build the plan block-by-block so C and P are never both in memory.
        blocks = []
        for s in range(0, m, block_size):
            e = min(s + block_size, m)
            Cb = cost_block_fn(s, e)
            Pb = th.exp(u.unsqueeze(-1) + v[s:e].unsqueeze(-2) - gamma_f_ * Cb)
            del Cb
            blocks.append(Pb)
        P = th.cat(blocks, dim=-1)
        del blocks
        if round:
            P = round_altschuler(P, r, c)
        # Honour the promise made in the precision warning: hand back the
        # caller's dtype even if the computation was promoted to double.
        P = P.to(dtype)
        return (P, opt_logs) if log else P
    else:
        if round:
            cost = blocked_rounded_cost(u, v, r, c, cost_block_fn, m, gamma_f_, block_size)
        else:
            cost = blocked_transport_cost(u, v, cost_block_fn, m, gamma_f_, block_size)
        # See note above: return the result in the caller's dtype.
        cost = cost.to(dtype)
        return (cost, opt_logs) if log else cost