# Source code for mdot_tnt.truncated_newton

"""Truncated Newton projector for the MDOT algorithm."""

import warnings
from typing import Any, Callable, Dict, Tuple, Union

import torch as th


class TruncatedNewtonProjector:
    """
    Truncated Newton projector for the MDOT algorithm.

    Projects onto the set of couplings satisfying marginal constraints using
    a preconditioned conjugate gradient method within a Newton framework.
    """

    def __init__(self, device: th.device, dtype: th.dtype, **kwargs: Any) -> None:
        """
        Initialize the projector.

        Args:
            device: PyTorch device for computations.
            dtype: Data type for tensors.
            **kwargs: Additional options (debug: bool for verbose output).
        """
        self.device = device
        # Discount factor applied to the off-diagonal part of the Newton system
        # inside `newton_solve`; carried across Newton iterations of `project`.
        self.rho = th.zeros(1, device=device, dtype=dtype)
        self.debug = kwargs.get("debug", False)
        # Row/column log-sum-exp operators; bound to the current cost matrix
        # at the start of every `project` call.
        self.LSE_r: Callable[[th.Tensor], th.Tensor]
        self.LSE_c: Callable[[th.Tensor], th.Tensor]

    def project(
        self,
        gamma_C: th.Tensor,
        log_r: th.Tensor,
        log_c: th.Tensor,
        eps_d: Union[float, th.Tensor],
        u: th.Tensor,
        v: th.Tensor,
    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], bool]:
        """
        Project onto the set of couplings that satisfy the marginal constraints.

        The coupling is parameterized as P = exp(u[:, None] + v[None, :] - gamma_C).
        Note that `u` and `v` are updated in place.

        Args:
            gamma_C: The cost matrix scaled by gamma, shape (n, m).
            log_r: Log of row marginals, shape (n,).
            log_c: Log of column marginals, shape (m,).
            eps_d: Convergence tolerance for the dual gradient norm.
            u: Initial row dual variables, shape (n,).
            v: Initial column dual variables, shape (m,).

        Returns:
            u: Updated row dual variables.
            v: Updated column dual variables.
            logs: Dictionary with optimization statistics.
            success: Whether projection converged successfully.

        Raises:
            ValueError: If NaNs appear in `u` or `v`, or if a sub-solver raises.
        """
        logs: Dict[str, Any] = {
            "errs": [],
            "ls_func_cnt": 0,
            "chisinkhorn_steps": 0,
            "newtonsolve_steps": 0,
            # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
            "deltas": [],
            "all_newtonsolve_steps": [],
        }

        def success_fn(err_):
            # In case of errors or issues, 10 times the tolerance level is
            # considered a good enough solution to keep MDOT going.
            return err_ < 10 * eps_d

        r = log_r.exp()
        c = log_c.exp()

        # Each LSE operation costs 4 * n^2 operations.
        self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
        self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)

        log_c_P = v + self.LSE_c(u)
        v += log_c - log_c_P  # Ensure c=c(P)
        log_r_P = u + self.LSE_r(v)
        k = 8

        u, v, log_r_P, err, k_ = self.chi_sinkhorn(
            u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
        )
        r_P = log_r_P.exp()
        logs["errs"].append(err)
        logs["chisinkhorn_steps"] = k_
        k += k_

        while err > eps_d:
            beta = 0.5
            # Forcing term of the inexact Newton step (relative CG tolerance).
            eta_k = th.max(err, 0.9 * (eps_d / err))
            grad_k = r_P - r
            self.rho = th.max(th.zeros_like(self.rho), self.rho)
            P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
            # Diagonal of P diag(1/c) P^T, used by the CG preconditioner.
            diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
            k += 8
            delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000
            )
            del P  # Free up memory
            if not pcg_success:
                k += matmul_cnt
                logs["n_iter"] = k
                msg = f"PCG did not converge. TruncatedNewton returning with success={success_fn(err)}"
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            # Start the next Newton iteration from a smaller discount factor.
            self.rho = th.max(th.zeros_like(self.rho), 1.0 - (1.0 - rho) * 4.0)
            k += matmul_cnt
            logs["newtonsolve_steps"] += matmul_cnt

            # Backtracking (Armijo) line search along (delta_u, delta_v).
            alpha = th.ones_like(self.rho)
            log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
            k += 4
            linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
            if not linear_decr > 0:
                logs["n_iter"] = k
                msg = f"Linear decrease condition not satisfied. TruncatedNewton returning with success={success_fn(err)}"
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr
            while not armijo:
                # Check armijo condition for batch elements where err > eps_d
                alpha *= 0.5
                if alpha < 1e-9:
                    logs["n_iter"] = k
                    msg = f"Line search did not converge. TruncatedNewton returning with success={success_fn(err)}"
                    warnings.warn(msg)
                    return u, v, logs, success_fn(err)
                log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
                k += 4
                logs["ls_func_cnt"] += 4
                armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr

            u += alpha * delta_u
            v += alpha * delta_v

            # The below error (before the Sinkhorn update) is used
            # to measure the progress of the algorithm with TruncatedNewton steps.
            err_before_sk = (c - log_c_P.exp()).abs().sum(-1)
            err_before_sk += (r - (u + self.LSE_r(v)).exp()).abs().sum(-1)

            # Sinkhorn update to ensure c=c(P).
            v += log_c - log_c_P
            log_r_P = u + self.LSE_r(v)
            k += 4

            u, v, log_r_P, err, k_ = self.chi_sinkhorn(
                u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
            )
            r_P = log_r_P.exp()
            logs["chisinkhorn_steps"] += k_
            k += k_
            logs["errs"].append(err)
            logs["deltas"].append(
                th.min(
                    (logs["errs"][-2] - err_before_sk)
                    / ((1 - eta_k) * logs["errs"][-2])
                ).item()
            )

            if u.isnan().any() or v.isnan().any():
                raise ValueError("NaNs encountered in u or v")

        logs["n_iter"] = k

        # Since we already computed log_r_P, we can use it to perform one last
        # Sinkhorn update on rows.
        delta_u = log_r - log_r_P
        u += delta_u

        return u, v, logs, True

    def chi_sinkhorn(
        self,
        u: th.Tensor,
        v: th.Tensor,
        log_r: th.Tensor,
        log_c: th.Tensor,
        log_r_P: th.Tensor,
        eps_chi: Union[float, th.Tensor],
        maxOps: float = float("inf"),
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, int]:
        """
        Run Sinkhorn sweeps until the row marginal is chi-squared close to r.

        Each sweep rescales rows (u), then columns (v), using `self.LSE_r` /
        `self.LSE_c`, which must already be bound (as done in `project`).
        The loop stops once sum((r - r_P)**2 / r_P) <= eps_chi.
        `u` and `v` are updated in place.

        Args:
            u: Row dual variables.
            v: Column dual variables.
            log_r: Log of target row marginals.
            log_c: Log of target column marginals.
            log_r_P: Log of the current row marginal of P.
            eps_chi: Chi-squared tolerance on the row marginal.
            maxOps: Operation budget; each sweep counts as 8 (two LSE calls).

        Returns:
            u, v, log_r_P (after the last sweep), err (L1 distance between the
            row marginal and r), and k (operations spent).

        Raises:
            ValueError: If the budget is exhausted before reaching `eps_chi`.
        """
        k = 0
        r = log_r.exp()
        r_P = log_r_P.exp()
        err = (r - r_P).norm(p=1, dim=-1)
        chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)

        while chi_squared > eps_chi and k < maxOps:
            u += log_r - log_r_P
            log_c_P = v + self.LSE_c(u)
            v += log_c - log_c_P
            log_r_P = u + self.LSE_r(v)
            r_P = log_r_P.exp()
            err = (r - r_P).norm(p=1, dim=-1)
            chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
            k += 8

        # Only fail when the budget ran out *without* converging; converging
        # on the very last allowed sweep is a success.
        if k >= maxOps and not chi_squared <= eps_chi:
            raise ValueError(f"Chi-Sinkhorn did not converge in maxOps={maxOps} operations")

        return u, v, log_r_P, err, k

    def newton_solve(
        self,
        P: th.Tensor,
        c: th.Tensor,
        diag_PPc: th.Tensor,
        grad_k: th.Tensor,
        r_P: th.Tensor,
        err: th.Tensor,
        beta: float = 0.5,
        eta_k: Union[float, th.Tensor] = 0.5,
        maxIter: int = 500,
    ) -> Tuple[th.Tensor, th.Tensor, int, th.Tensor, bool]:
        """
        Approximately solve the Newton system with preconditioned CG.

        The system solved by CG is (diag(r_P) - rho * P diag(1/c) P^T) x = -grad_k,
        with a diagonal preconditioner r_P - rho * diag_PPc. The "true" residual
        uses rho = 1. Whenever the true residual is above tol = err * eta_k, rho
        is moved towards 1 and CG restarts from the preconditioned guess x0.
        Note: `rho` starts as an alias of `self.rho` and is modified in place.

        Args:
            P: Current coupling, shape (n, m).
            c: Column marginals, shape (m,).
            diag_PPc: Diagonal of P diag(1/c) P^T, shape (n,).
            grad_k: Gradient w.r.t. u (r_P - r), shape (n,).
            r_P: Current row marginal of P, shape (n,).
            err: Current marginal error; scales the tolerance.
            beta: Controls the inner CG tolerance 0.5 * (1 - beta) * tol.
            eta_k: Forcing term; tol = err * eta_k.
            maxIter: CG budget; stops after more than 2 * maxIter matmuls.

        Returns:
            x (step for u), -P^T x / c (step for v), matmul count, final rho,
            and a success flag (may be a single-element bool tensor).

        Raises:
            ValueError: If the initial guess is not a descent direction, or if
                NaN/inf residuals occur.
        """
        rho = self.rho
        tol = err * eta_k

        def mml(x_):
            # Computes P diag(1/c) P^T x_ (two matmuls).
            return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)

        def M(rho_):
            # Diagonal preconditioner
            return r_P - rho_ * diag_PPc

        M_rho = M(th.ones_like(self.rho))
        # Guard against non-positive preconditioner entries.
        M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()
        x0 = -grad_k / M_rho
        PPc_x0 = mml(x0)
        matmul_cnt = 2
        r_P_x0 = r_P * x0

        x = x0.clone()
        PPc_x = PPc_x0.clone()
        r_P_x = r_P_x0.clone()
        res_true = r_P_x0 - PPc_x + grad_k

        linear_decr = (x * -grad_k).sum(-1)
        if linear_decr <= 0:
            raise ValueError("Linear decrease condition not satisfied")

        r_true_norm = res_true.norm(p=1, dim=-1)
        best_sol = x.clone()
        best_r_true_norm = r_true_norm.clone()

        done = False
        success = True
        while best_r_true_norm > tol:
            best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
            # Move rho towards 1 where the true residual is still too large.
            rho[r_true_norm > tol] = 1.0 - (1.0 - rho[r_true_norm > tol]) * 0.25
            M_rho = M(rho)

            # Restart CG from the preconditioned initial guess (matmul_cnt is
            # always positive here, so this restart is unconditional).
            x = x0.clone()
            PPc_x = PPc_x0.clone()
            r_P_x = r_P_x0.clone()

            Fr_x = r_P_x - rho * PPc_x
            res = Fr_x + grad_k
            res_true = r_P_x - PPc_x + grad_k
            r_true_norm = res_true.norm(p=1, dim=-1)
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
            linear_decr = (x * -grad_k).sum(-1)
            if (best_r_true_norm < tol).all() and (linear_decr > 0).all():
                break

            y = res / M_rho
            p = -y.clone()
            ry_old = (res * y).sum(-1, keepdim=True)
            r_norm = res.norm(p=1, dim=-1)

            while (r_norm > 0.5 * (1 - beta) * tol)[best_r_true_norm > tol].any():
                PPc_p = mml(p)
                matmul_cnt += 2
                Fr_p = (r_P * p) - rho * PPc_p
                quad = (Fr_p * p).sum(-1, keepdim=True)
                if (quad <= 0)[best_r_true_norm > tol].any():
                    warnings.warn(
                        "Warning: negative curvature encountered in CG. Returning best solution. "
                        f"Residual norm less than error: {(best_r_true_norm < err).item()}"
                    )
                    x = best_sol.clone()
                    done = True
                    success = best_r_true_norm < err
                    warnings.warn("Resetting discount factor rho = 0")
                    rho = th.zeros_like(self.rho)
                    break

                alpha = ry_old / quad
                x += alpha * p
                res += alpha * Fr_p
                r_norm = res.norm(p=1, dim=-1)

                if (
                    th.isnan(r_norm)[best_r_true_norm > tol].any()
                    or th.isinf(r_norm)[best_r_true_norm > tol].any()
                ):
                    raise ValueError("NaNs or infs encountered in r_norm")

                # Track the true (rho = 1) residual incrementally.
                PPc_x += alpha * PPc_p
                r_P_x = r_P * x
                res_true = r_P_x - PPc_x + grad_k
                r_true_norm = res_true.norm(p=1, dim=-1)
                best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
                best_r_true_norm = th.min(r_true_norm, best_r_true_norm)

                linear_decr = (x * -grad_k).sum(-1)
                if (best_r_true_norm <= tol).all() and (linear_decr > 0).all():
                    done = True
                    success = True
                    break

                if matmul_cnt > 2 * maxIter:
                    warnings.warn("PCG did not converge.")
                    done = True
                    success = False
                    break

                y = res / M_rho
                ry_new = (res * y).sum(-1, keepdim=True)
                p = -y + (ry_new / ry_old) * p
                ry_old = ry_new.clone()

            if done:
                break

        if r_true_norm <= tol:
            success = True

        x = best_sol
        # Step for v implied by the step for u (keeps column marginals fixed
        # to first order).
        Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
        matmul_cnt += 1
        return x, -Pc_x, matmul_cnt, rho, success