Source code for mdot_tnt.mdot

"""Core MDOT solver using truncated Newton projection."""

import warnings
from typing import Any, Dict, List, Tuple, Union

import torch as th

from mdot_tnt.truncated_newton import TruncatedNewtonProjector


def preprocess_marginals(
    r: th.Tensor, c: th.Tensor, C: th.Tensor, eps: float
) -> Tuple[Tuple[th.Tensor, th.Tensor], Tuple[th.Tensor, th.Tensor], th.Tensor]:
    """
    Trim the lightest entries of both marginals and shrink the cost matrix to match.

    For each marginal, entries are visited from smallest to largest and discarded
    while their running total does not exceed ``eps``. The surviving entries are
    then shifted by an equal amount so that they sum to one again.

    Args:
        r: The row marginal of shape (n,).
        c: The column marginal of shape (m,).
        C: The cost matrix of shape (n, m).
        eps: Cumulative-mass threshold; an entry is kept only once the running
            sum of the sorted marginal is strictly greater than this value.

    Returns:
        A tuple containing:
            - (r_new, r_keep): The trimmed row marginal and the original indices
              of the kept entries (listed in ascending order of mass).
            - (c_new, c_keep): The same pair for the column marginal.
            - C: The cost matrix restricted to the kept rows and columns, in the
              same order as ``r_keep`` / ``c_keep``.
    """

    def _trim(marginal: th.Tensor, threshold: float) -> Tuple[th.Tensor, th.Tensor]:
        # Walk the entries from lightest to heaviest, tracking accumulated mass.
        ascending = th.sort(marginal, dim=-1, descending=False)
        running_mass = th.cumsum(ascending.values, dim=-1)
        kept = ascending.indices[running_mass > threshold]
        trimmed = marginal[kept]
        # Spread whatever is missing from a unit total evenly over the survivors.
        deficit = 1 - trimmed.sum(-1)
        return trimmed + deficit / trimmed.size(-1), kept

    r_new, r_keep = _trim(r, eps)
    c_new, c_keep = _trim(c, eps)

    print(
        f"Dropped {r.size(-1) - r_new.size(-1)} entries from r and {c.size(-1) - c_new.size(-1)} entries from c."
    )

    # Select the surviving rows first, then the surviving columns.
    kept_rows = C[r_keep]
    C = kept_rows[:, c_keep]

    return (r_new, r_keep), (c_new, c_keep), C
def smooth_marginals(
    r: th.Tensor,
    c: th.Tensor,
    eps: Union[th.Tensor, float],
    w_r: float = 0.5,
    w_c: float = 0.5,
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Smooth the marginals by mixing each of them with the uniform distribution.

    Each marginal is replaced by a convex combination of itself and the uniform
    distribution over its support; the mixing weight is ``w_r * eps`` for the
    rows and ``w_c * eps`` for the columns.

    Args:
        r: The row marginal of shape (n,).
        c: The column marginal of shape (m,).
        eps: The total smoothing budget shared by the two marginals. Values
            above 1 are clamped to 1. May be a tensor or a plain Python number.
        w_r: The fraction of the budget spent on the row marginal.
        w_c: The fraction of the budget spent on the column marginal.

    Returns:
        A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps from the
        original marginals.

    Raises:
        AssertionError: If ``w_r + w_c`` is not (numerically) equal to 1.
    """
    # Compare with a tolerance: exact float equality rejects valid splits such
    # as 0.7 + 0.3, which evaluates to 0.9999999999999999. The exception type is
    # kept as AssertionError for backward compatibility.
    assert abs(w_r + w_c - 1.0) <= 1e-9, "w_r and w_c must sum to 1"

    # Accept plain numbers too by lifting them onto r's dtype and device.
    if not th.is_tensor(eps):
        eps = th.as_tensor(eps, dtype=r.dtype, device=r.device)

    # The trailing unit dimension lets eps broadcast against the marginals.
    eps = eps.clamp(max=1.0).unsqueeze(-1)
    r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
    c_hat = (1 - w_c * eps) * c + w_c * eps * th.ones_like(c) / c.size(-1)

    return r_hat, c_hat
def adjust_schedule(q: float, deltas: Union[List[float], None] = None) -> float:
    """
    Update the annealing adjustment factor from the Truncated Newton deltas.

    The smallest delta (capped at 1.0) decides the outcome: below 0.5 the
    factor is replaced by its square root, above 0.9 it is squared, and
    otherwise it is left as is.

    Args:
        q: The current temperature annealing schedule adjustment factor.
        deltas: The list of deltas from the Truncated Newton method; see Sec. 3.3 of
            Kemertas et al. (2025). ``None`` leaves ``q`` untouched.

    Returns:
        The new temperature annealing schedule adjustment factor.
    """
    if deltas is None:
        return q

    # The appended 1.0 makes an empty list behave like a fully successful run.
    worst_delta = min(deltas + [1.0])

    if worst_delta < 0.5:
        return q**0.5
    if worst_delta > 0.9:
        return q**2
    return q
def mdot(
    r: th.Tensor,
    c: th.Tensor,
    C: th.Tensor,
    gamma_f: float,
    gamma_i: float = 16,
    p: float = 1.5,
    q: float = 2.0,
) -> Tuple[th.Tensor, th.Tensor, float, int, Dict[str, Any]]:
    """
    Solve the entropic-regularized optimal transport problem using the MDOT method.

    This implements the MDOT method introduced in the paper:
    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients"
    by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand.
    URL: https://arxiv.org/abs/2307.08507

    Here, we use the Truncated Newton method for projection. The temperature is
    annealed from ``min(gamma_i, gamma_f)`` up to ``gamma_f``; at every level the
    marginals are smoothed, the projector is run, and the next level's dual
    variables are warm-started by linear extrapolation in the temperature.

    Args:
        r: The first marginal of shape (n,).
        c: The second marginal of shape (m,).
        C: The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
        gamma_f: The final temperature (inverse of the regularization weight).
        gamma_i: The initial temperature.
        p: The exponent for the epsilon function, used to determine the stopping
            criterion for the dual gradient.
        q: The temperature annealing (or mirror descent step size) schedule
            adjustment factor.

    Returns:
        A tuple containing:
            - u: The row dual variables of shape (n,).
            - v: The column dual variables of shape (m,).
            - gamma: The last entry of the temperature schedule that was kept
              (0.0 if the projection already failed at the first level).
            - k_total: The sum of ``n_iter`` over all projection logs plus ``t - 1``,
              where ``t`` is the annealing step counter.
            - logs: Dictionary with the keys ``"proj_logs"`` (one entry per
              projection), ``"eps"`` (initialised empty and not filled here),
              ``"success"`` (flag from the last projection) and ``"gammas"``.
    """
    projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)

    # Entropies of the marginals; the 1e-30 offset keeps log() finite at zero entries.
    H_r = -(r * (r + 1e-30).log()).sum(-1)
    H_c = -(c * (c + 1e-30).log()).sum(-1)
    H_min = th.min(H_r, H_c)

    def eps_fn(g_):
        # Tolerance used at temperature g_: shrinks polynomially as g_ grows.
        return H_min / (g_**p)

    logs: Dict[str, Any] = {
        "proj_logs": [],
        "eps": [],
    }

    t = 1
    done = False
    gamma = min(gamma_i, gamma_f)
    # The leading 0.0 gives the first extrapolation step a well-defined
    # "previous" temperature.
    gammas = [0.0, gamma]

    while not done:
        done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)

        eps_d = eps_fn(gamma)

        r_hat, c_hat = smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)

        if t == 1:
            # Cold start: initialise the duals from the smoothed log-marginals.
            u_init, v_init = r_hat.log(), c_hat.log()
            u_cur, v_cur = u_init.clone(), v_init.clone()

        # Remember the previous iterate, both for the warm start and as the
        # fallback result if this projection fails.
        u_prev, v_prev = u_cur.clone(), v_cur.clone()
        gamma_C = gamma * C
        u_cur, v_cur, proj_log, success = projector.project(
            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init
        )

        logs["proj_logs"].append(proj_log)
        if not success:
            last_gamma = gammas[-2]
            if last_gamma > 0:
                warnings.warn(
                    f"Projection failed. Returning result at the last temperature: {1 / last_gamma:.4e}"
                )
            else:
                # On the very first level the previous entry is the 0.0
                # placeholder; dividing by it would raise ZeroDivisionError
                # instead of emitting the warning.
                warnings.warn(
                    "Projection failed at the initial temperature. "
                    "Returning the initial dual variables."
                )
            u_cur = u_prev.clone()
            v_cur = v_prev.clone()
            gammas = gammas[:-1]
            break

        q = adjust_schedule(q, proj_log["deltas"])
        gamma = min(gamma * q, gamma_f)

        if not done:
            # Generate warm-started initializations for the next iteration.
            u_init = u_cur + (u_cur - u_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            v_init = v_cur + (v_cur - v_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            gammas.append(gamma)
            t += 1

    k_total = sum([log["n_iter"] for log in logs["proj_logs"]])
    k_total += t - 1
    logs["success"] = success
    logs["gammas"] = gammas

    return u_cur, v_cur, gammas[-1], k_total, logs