API Reference

Main Functions

mdot_tnt.solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)

Solve the entropic-regularized optimal transport problem.

The inputs r, c, and C must be torch tensors.

Parameters:
  • r – Row marginal of shape (n,). Must sum to 1.

  • c – Column marginal of shape (m,). Must sum to 1.

  • C – Cost matrix of shape (n, m). Recommended to scale entries to [0, 1].

  • gamma_f – Temperature (inverse of the regularization weight). For many problems, stable up to 2^18. Higher values return more accurate solutions but take longer to converge. Use double precision if gamma_f is large.

  • drop_tiny – If either marginal is known to be sparse, set this to True to drop tiny entries for a speedup. If return_plan is True, the returned plan will be in the original dimensions.

  • return_plan – If True, return the optimal transport plan rather than the cost.

  • round – If True, use the rounding algorithm of Altschuler et al. (2017) to (a) return a feasible plan if return_plan is True, or (b) return the cost of the rounded plan if return_plan is False.

  • log – If True, additionally return a dictionary containing logs of the optimization process.

  • devices – List of CUDA devices (int, str, or torch.device) to distribute column blocks of C across. Enables multi-GPU solving of a single OT problem by splitting the cost matrix along the column dimension and running blocked operations in parallel. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. Ignored when only one device is given.

  • num_gpus – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.

Returns:

Transport cost (scalar tensor) if return_plan is False, or the transport plan of shape (n, m) if return_plan is True. If log is True, returns a tuple of (result, logs_dict).
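To make the temperature convention concrete, the following is a minimal pure-Python sketch of plain log-domain Sinkhorn. It is a different and simpler algorithm than the MDOT-TNT solver behind solve_OT; it only illustrates how gamma enters the entropic problem, assuming a plan of the form P[i][j] = exp(u[i] + v[j] - gamma * C[i][j]) (the same form used by blocked_transport_cost in the lowmem module). The helper names logsumexp and sinkhorn_cost are hypothetical.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sinkhorn_cost(r, c, C, gamma, iters=500):
    """Plain log-domain Sinkhorn for entropic OT (NOT the MDOT-TNT algorithm).

    The plan is P[i][j] = exp(u[i] + v[j] - gamma * C[i][j]); a larger gamma
    means weaker entropic regularization and a plan closer to the unregularized optimum.
    """
    n, m = len(r), len(c)
    u, v = [0.0] * n, [0.0] * m
    for _ in range(iters):
        # Row update: rescale so that the row sums of P equal r.
        u = [math.log(r[i]) - logsumexp([v[j] - gamma * C[i][j] for j in range(m)])
             for i in range(n)]
        # Column update: rescale so that the column sums of P equal c.
        v = [math.log(c[j]) - logsumexp([u[i] - gamma * C[i][j] for i in range(n)])
             for j in range(m)]
    P = [[math.exp(u[i] + v[j] - gamma * C[i][j]) for j in range(m)] for i in range(n)]
    cost = sum(P[i][j] * C[i][j] for i in range(n) for j in range(m))
    return cost, P
```

For uniform marginals r = c = [0.5, 0.5] and C = [[0, 1], [1, 0]], the sketch returns 1/(1 + e^gamma): about 0.269 at gamma = 1 and essentially 0 (the unregularized optimum) at gamma = 64.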

mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)

Solve multiple entropic-regularized optimal transport problems in a single batched call.

This function provides significant speedup (5-10x) over solving problems sequentially by amortizing GPU synchronization overhead across all problems in the batch.

Parameters:
  • r (Tensor) – (batch, n) row marginals. Each row must sum to 1.

  • c (Tensor) – (batch, m) column marginals. Each row must sum to 1.

  • C (Tensor) – Cost matrix. Either (n, m) for shared cost across all problems, or (batch, n, m) for per-problem costs. Recommended to scale to [0, 1].

  • gamma_f (float) – Temperature (inverse of regularization weight). Higher values give more accurate solutions but take longer. Stable up to ~2^18 with float64.

  • drop_tiny (bool) – Not supported in batched solver. Raises NotImplementedError if True.

  • return_plan (bool) – If True, return transport plans instead of costs.

  • round (bool) – If True, apply Altschuler rounding for feasible solutions.

  • log (bool) – If True, also return optimization logs.

  • devices (Optional[Any]) – List of CUDA devices (int, str, or torch.device) to distribute the batch across. The batch is split evenly; each GPU solves its share independently. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. Ignored when only one device is given.

  • num_gpus (Optional[int]) – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.

Returns:

If return_plan is False, a (batch,) tensor of transport costs; if return_plan is True, a (batch, n, m) tensor of transport plans. If log is True, a tuple (result, logs_dict).

Return type:

Tensor, or Tuple[Tensor, dict] if log is True

Example

>>> import torch
>>> from mdot_tnt import solve_OT_batched
>>> # Solve 32 OT problems of size 512×512
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> r = r / r.sum(-1, keepdim=True)
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> c = c / c.sum(-1, keepdim=True)
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)

mdot_tnt.lowmem.solve_OT_lowmem(r, c, C=None, X=None, Y=None, cost_fn=None, gamma_f=1024.0, block_size=None, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)

Solve entropic-regularized OT with O(nk) working memory, where k is the number of columns per block (block_size).

Accepts either a dense cost matrix or point clouds with a cost function. In point-cloud mode the cost matrix is never materialized, giving true O(nk + (n+m)d) memory, where d is the point dimension.

Parameters:
  • r (Tensor) – Row marginal, shape (n,). Must sum to 1.

  • c (Tensor) – Column marginal, shape (m,). Must sum to 1.

  • C (Optional[Tensor]) – Dense cost matrix, shape (n, m). Mutually exclusive with X/Y. Recommended to scale entries to [0, 1].

  • X (Optional[Tensor]) – Source points, shape (n, d). Use with Y and optionally cost_fn.

  • Y (Optional[Tensor]) – Target points, shape (m, d). Use with X.

  • cost_fn (Optional[Callable[[Tensor, Tensor], Tensor]]) – (X, Y_block) -> C_block of shape (n, k). Defaults to squared_euclidean. Only used when X/Y are provided.

  • gamma_f (float) – Temperature (inverse regularization weight).

  • block_size (Optional[int]) – Columns per block. Controls the memory/compute trade-off:
      - None or m: fastest, O(nm) or O(nk) memory
      - sqrt(m): good trade-off
      - 1: minimum memory (slowest)

  • drop_tiny (bool) – Drop tiny marginal entries for speedup with sparse marginals.

  • return_plan (bool) – If True, return the full transport plan. Note: the plan is inherently O(nm); in point-cloud mode it is built block-by-block so that C and P are never both in memory simultaneously.

  • round (bool) – Use Altschuler rounding for feasible plans / costs.

  • log (bool) – Additionally return optimization logs.

  • devices (Optional[Any]) – List of CUDA devices (int, str, or torch.device) to distribute column blocks across. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. When a single device is given the standard single-GPU solver is used.

  • num_gpus (Optional[int]) – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.

Returns:

Transport cost (scalar) if return_plan is False, or the transport plan (n, m) if return_plan is True. If log is True, returns (result, logs).

Core Modules

mdot

Core MDOT solver using truncated Newton projection.

mdot_tnt.mdot.preprocess_marginals(r, c, C, eps)

Drop the smallest marginal entries whose cumulative sum is below a threshold.

Parameters:
  • r (Tensor) – The row marginal of shape (n,).

  • c (Tensor) – The column marginal of shape (m,).

  • C (Tensor) – The cost matrix of shape (n, m).

  • eps (float) – The threshold for the cumulative sum of the marginal entries to be dropped.

Returns:

A tuple containing:

  • (r_new, r_keep): The new row marginal and indices of kept entries.

  • (c_new, c_keep): The new column marginal and indices of kept entries.

  • C: The cost matrix with corresponding rows and columns dropped.
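The dropping rule can be sketched in pure Python on plain lists. This is a hypothetical illustration, not the library's code: entries are removed smallest-first while their cumulative mass stays strictly below eps, and the kept entries are renormalized to sum to 1 (the renormalization is an assumption, since the docstring does not state it). The helper names drop_tiny_entries and drop_rows_cols are hypothetical.

```python
def drop_tiny_entries(p, eps):
    """Hypothetical sketch: drop the smallest entries of marginal p.

    Entries are dropped smallest-first while their cumulative mass stays below eps.
    Renormalizing the kept entries is an assumption of this sketch.
    Returns (p_new, keep), where keep lists the original indices of kept entries.
    """
    order = sorted(range(len(p)), key=lambda i: p[i])  # ascending by mass
    dropped, total = set(), 0.0
    for i in order:
        if total + p[i] >= eps:  # dropping entry i would reach the threshold
            break
        total += p[i]
        dropped.add(i)
    keep = [i for i in range(len(p)) if i not in dropped]
    kept_mass = sum(p[i] for i in keep)
    p_new = [p[i] / kept_mass for i in keep]
    return p_new, keep


def drop_rows_cols(C, r_keep, c_keep):
    """Restrict the cost matrix to the kept rows and columns."""
    return [[C[i][j] for j in c_keep] for i in r_keep]
```

With r = [0.001, 0.5, 0.002, 0.497] and eps = 0.01, entries 0 and 2 (total mass 0.003) are dropped and keep is [1, 3].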

mdot_tnt.mdot.smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5)

Smooth the marginals by adding a small amount of uniform mass to each entry.

Parameters:
  • r (Tensor) – The row marginal of shape (n,).

  • c (Tensor) – The column marginal of shape (m,).

  • eps (Tensor) – The amount of mass to add to each entry.

  • w_r (float) – The weight for the row marginal.

  • w_c (float) – The weight for the column marginal.

Return type:

Tuple[Tensor, Tensor]

Returns:

A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps from the original marginals.
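One way to obtain such a TV guarantee is to mix each marginal with the uniform distribution. The sketch below assumes that interpretation: eps is treated as a total TV budget split between r and c by the weights w_r and w_c (assumed to sum to 1), and TV is taken as half the L1 distance. Mixing p with weight a moves it by a * TV(p, uniform) <= a. Whether smooth_marginals uses exactly this formula is not stated in the docstring; the helper names are hypothetical.

```python
def smooth(p, a):
    """Mix p with the uniform distribution: (1 - a) * p + a / len(p)."""
    n = len(p)
    return [(1.0 - a) * x + a / n for x in p]


def tv(p, q):
    """Total variation distance, taken here as 0.5 * L1 norm of the difference."""
    return 0.5 * sum(abs(x - y) for x, y in zip(p, q))


def smooth_marginals_sketch(r, c, eps, w_r=0.5, w_c=0.5):
    """Hypothetical sketch: split the TV budget eps between r and c by weight."""
    return smooth(r, w_r * eps), smooth(c, w_c * eps)
```

Every smoothed entry is strictly positive, which keeps log(r_hat) and log(c_hat) finite for log-domain solvers.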

mdot_tnt.mdot.adjust_schedule(q, deltas=None)

Adjust the temperature annealing schedule based on the success of the Truncated Newton method.

Parameters:
  • q (float) – The current temperature annealing schedule adjustment factor.

  • deltas (Optional[List[float]]) – The list of deltas from the Truncated Newton method; see Sec. 3.3 of Kemertas et al. (2025).

Return type:

float

Returns:

The new temperature annealing schedule adjustment factor.

mdot_tnt.mdot.mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0)

Solve the entropic-regularized optimal transport problem using the MDOT method.

This implements the MDOT method introduced in the paper: “Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients” by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand. URL: https://arxiv.org/abs/2307.08507

Here, we use the Truncated Newton method for projection.

Parameters:
  • r (Tensor) – The first marginal of shape (n,).

  • c (Tensor) – The second marginal of shape (m,).

  • C (Tensor) – The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].

  • gamma_f (float) – The final temperature (inverse of the regularization weight).

  • gamma_i (float) – The initial temperature.

  • p (float) – The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.

  • q (float) – The temperature annealing (or mirror descent step size) schedule adjustment factor.

Returns:

A tuple containing:

  • u: The row dual variables of shape (n,).

  • v: The column dual variables of shape (m,).

  • gamma: The final temperature achieved.

  • k_total: The total number of O(n^2) primitive operations.

  • logs: Dictionary with optimization statistics.
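To see how gamma_i, gamma_f and q interact, the sketch below assumes a geometric schedule in which each outer step multiplies the temperature by q and the last step is capped at gamma_f. In the library, adjust_schedule may change q between steps based on the Truncated Newton deltas; this sketch keeps q fixed, and geometric_schedule is a hypothetical helper.

```python
def geometric_schedule(gamma_i, gamma_f, q):
    """Hypothetical sketch: multiply gamma by a fixed q each outer step, capped at gamma_f."""
    gammas = [gamma_i]
    while gammas[-1] < gamma_f:
        gammas.append(min(gammas[-1] * q, gamma_f))
    return gammas
```

Under this assumption, the defaults gamma_i=16, q=2.0 with gamma_f=1024 give the seven temperatures 16, 32, 64, 128, 256, 512, 1024, i.e. six temperature increases.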

batched

Batched MDOT-TNT solver for solving multiple optimal transport problems simultaneously.

This module provides batched versions of the MDOT-TNT solver that achieve significant speedups (5-10x) over sequential solving by amortizing GPU synchronization overhead across all problems in a batch.

Key insight: The main solver has many Python while-loops that check convergence, each requiring a GPU→CPU sync. By batching N problems together, we do one sync per iteration for the entire batch instead of N syncs.
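A toy count, with no GPU involved, illustrates the claim. It assumes each loop iteration performs exactly one convergence check (one GPU→CPU sync) and that finished problems are masked out of the batched loop. Sequential solving then pays the sum of the per-problem iteration counts, while the batched loop pays only the maximum. count_syncs is a hypothetical helper.

```python
def count_syncs(iters_needed):
    """Count convergence checks (one sync per loop iteration) in a toy model.

    iters_needed[b] is the number of iterations problem b needs to converge.
    Returns (sequential_syncs, batched_syncs).
    """
    # Sequential: each problem runs its own loop with one check per iteration.
    sequential = sum(iters_needed)
    # Batched: a single loop for the whole batch; converged problems are masked
    # out and the loop stops once every problem has converged.
    batched = 0
    remaining = list(iters_needed)
    active = [n > 0 for n in remaining]
    while any(active):  # one sync per iteration for the entire batch
        batched += 1
        for b in range(len(remaining)):
            if active[b]:
                remaining[b] -= 1
                active[b] = remaining[b] > 0
    return sequential, batched
```

For four problems needing 10, 12, 8 and 15 iterations, the toy model gives 45 sequential syncs versus 15 batched ones.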

Supports:
  • Multiple marginal pairs with a shared cost matrix: r of shape (batch, n), c of shape (batch, m), C of shape (n, m).

  • Multiple OT problems with different costs: r of shape (batch, n), c of shape (batch, m), C of shape (batch, n, m).

Example usage:
>>> import torch
>>> from mdot_tnt.batched import solve_OT_batched
>>>
>>> # 32 problems, each 512-dimensional
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> r = r / r.sum(dim=-1, keepdim=True)
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> c = c / c.sum(dim=-1, keepdim=True)
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)  # Shared cost
>>>
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
>>> print(costs.shape)  # torch.Size([32])

class mdot_tnt.batched.BatchedTruncatedNewtonProjector(device, dtype, **kwargs)

Batched Truncated Newton projector for the MDOT algorithm.

Projects onto the set of couplings satisfying marginal constraints, processing multiple problems simultaneously for efficiency.

project(gamma_C, log_r, log_c, eps_d, u, v, active_mask=None)

Project onto the constraint set for all problems in the batch.

Parameters:
  • gamma_C (Tensor) – (batch, n, m) or (n, m) cost matrix scaled by gamma.

  • log_r (Tensor) – (batch, n) log of row marginals.

  • log_c (Tensor) – (batch, m) log of column marginals.

  • eps_d (Union[float, Tensor]) – Convergence tolerance, scalar or (batch,) tensor.

  • u (Tensor) – (batch, n) initial row dual variables.

  • v (Tensor) – (batch, m) initial column dual variables.

  • active_mask (Optional[Tensor]) – (batch,) bool tensor, True for problems to process.

Returns:

A tuple containing:

  • u: (batch, n) updated row dual variables.

  • v: (batch, m) updated column dual variables.

  • logs: Dictionary with optimization statistics.

  • success: (batch,) bool tensor indicating convergence per problem.

mdot_tnt.batched.solve_OT_batched(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False, devices=None, num_gpus=None)

Solve multiple entropic-regularized optimal transport problems in a single batched call.

This function provides significant speedup (5-10x) over solving problems sequentially by amortizing GPU synchronization overhead across all problems in the batch.

Parameters:
  • r (Tensor) – (batch, n) row marginals. Each row must sum to 1.

  • c (Tensor) – (batch, m) column marginals. Each row must sum to 1.

  • C (Tensor) – Cost matrix. Either (n, m) for shared cost across all problems, or (batch, n, m) for per-problem costs. Recommended to scale to [0, 1].

  • gamma_f (float) – Temperature (inverse of regularization weight). Higher values give more accurate solutions but take longer. Stable up to ~2^18 with float64.

  • drop_tiny (bool) – Not supported in batched solver. Raises NotImplementedError if True.

  • return_plan (bool) – If True, return transport plans instead of costs.

  • round (bool) – If True, apply Altschuler rounding for feasible solutions.

  • log (bool) – If True, also return optimization logs.

  • devices (Optional[Any]) – List of CUDA devices (int, str, or torch.device) to distribute the batch across. The batch is split evenly; each GPU solves its share independently. Pass [0, 1] or [torch.device('cuda:0'), torch.device('cuda:1')]. Mutually exclusive with num_gpus. Ignored when only one device is given.

  • num_gpus (Optional[int]) – Number of CUDA devices to use (starting from index 0). Mutually exclusive with devices.

Returns:

If return_plan is False, a (batch,) tensor of transport costs; if return_plan is True, a (batch, n, m) tensor of transport plans. If log is True, a tuple (result, logs_dict).

Return type:

Tensor, or Tuple[Tensor, dict] if log is True

Example

>>> import torch
>>> from mdot_tnt.batched import solve_OT_batched
>>> # Solve 32 OT problems of size 512×512
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> r = r / r.sum(-1, keepdim=True)
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
>>> c = c / c.sum(-1, keepdim=True)
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)

truncated_newton

Truncated Newton projector for the MDOT algorithm.

class mdot_tnt.truncated_newton.TruncatedNewtonProjector(device, dtype, **kwargs)

Truncated Newton projector for the MDOT algorithm.

Projects onto the set of couplings satisfying marginal constraints using a preconditioned conjugate gradient method within a Newton framework.
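For readers unfamiliar with the inner solver, here is a generic preconditioned conjugate gradient (PCG) sketch for a symmetric positive definite system A x = b, accessed only through a matrix-vector product. It uses a diagonal (Jacobi) preconditioner; that choice, the function name pcg, and its signature are assumptions for illustration, not the projector's actual newton_solve.

```python
def pcg(matvec, b, diag, tol=1e-10, max_iter=100):
    """Jacobi-preconditioned conjugate gradient for A x = b with A symmetric positive definite.

    matvec(x) returns A @ x as a list; diag holds the diagonal of A (the preconditioner).
    """
    n = len(b)
    x = [0.0] * n
    r = list(b)                             # residual b - A x for x = 0
    z = [r[i] / diag[i] for i in range(n)]  # preconditioned residual
    p = list(z)                             # first search direction
    rz = sum(r[i] * z[i] for i in range(n))
    for _ in range(max_iter):
        if sum(ri * ri for ri in r) ** 0.5 < tol:
            break
        Ap = matvec(p)
        alpha = rz / sum(p[i] * Ap[i] for i in range(n))  # exact line search step
        x = [x[i] + alpha * p[i] for i in range(n)]
        r = [r[i] - alpha * Ap[i] for i in range(n)]
        z = [r[i] / diag[i] for i in range(n)]
        rz_new = sum(r[i] * z[i] for i in range(n))
        p = [z[i] + (rz_new / rz) * p[i] for i in range(n)]  # A-conjugate direction
        rz = rz_new
    return x
```

In a truncated Newton method, a loop like this is typically stopped early (at a loose tolerance) rather than run to convergence, which is where the "truncated" comes from.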

project(gamma_C, log_r, log_c, eps_d, u, v)

Project onto the set of couplings that satisfy the marginal constraints.

Parameters:
  • gamma_C (Tensor) – The cost matrix scaled by gamma, shape (n, m).

  • log_r (Tensor) – Log of row marginals, shape (n,).

  • log_c (Tensor) – Log of column marginals, shape (m,).

  • eps_d (Union[float, Tensor]) – Convergence tolerance for the dual gradient norm.

  • u (Tensor) – Initial row dual variables, shape (n,).

  • v (Tensor) – Initial column dual variables, shape (m,).

Returns:

A tuple containing:

  • u: Updated row dual variables.

  • v: Updated column dual variables.

  • logs: Dictionary with optimization statistics.

  • success: Whether projection converged successfully.

chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_chi, maxOps=inf)

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, int]

newton_solve(P, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500)
Return type:

Tuple[Tensor, Tensor, int, Tensor, bool]

rounding

This module implements the rounding algorithm of Altschuler et al. (2017), “Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration”. It maps an approximate transport plan, such as one produced by Sinkhorn iterations, to a feasible plan in the set U(r, c) of nonnegative matrices whose row and column sums equal the marginals r and c, respectively. The solver in mdot.py uses it to round the transport plan and to compute the cost of the rounded plan. The implementation follows the original paper.

mdot_tnt.rounding.round_altschuler(P, r, c)

Performs rounding given a transport plan and marginals.

Parameters:
  • P (Tensor) – The input transport plan of shape (n, m).

  • r (Tensor) – Row marginal of shape (n,).

  • c (Tensor) – Column marginal of shape (m,).

Return type:

Tensor

Returns:

Rounded transport plan in feasible set U(r, c).
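The rounding procedure (Algorithm 2 in Altschuler et al., 2017) has three steps: scale down rows whose mass exceeds r, scale down columns whose mass exceeds c, then add a rank-one correction that distributes the remaining deficits. The pure-Python sketch below follows those steps on nested lists; round_to_feasible is a hypothetical name and is not the library's tensor implementation.

```python
def round_to_feasible(P, r, c):
    """Round a nonnegative plan P so its row sums are r and its column sums are c.

    Pure-Python sketch of Algorithm 2 in Altschuler et al. (2017); assumes sum(r) == sum(c).
    """
    n, m = len(r), len(c)
    # 1. Scale down every row whose mass exceeds the target r[i].
    row = [sum(P[i]) for i in range(n)]
    X = [[P[i][j] * (min(r[i] / row[i], 1.0) if row[i] > 0 else 1.0) for j in range(m)]
         for i in range(n)]
    # 2. Scale down every column whose mass exceeds the target c[j].
    col = [sum(X[i][j] for i in range(n)) for j in range(m)]
    X = [[X[i][j] * (min(c[j] / col[j], 1.0) if col[j] > 0 else 1.0) for j in range(m)]
         for i in range(n)]
    # 3. Both deficits are now nonnegative; fill them with a rank-one update.
    err_r = [r[i] - sum(X[i]) for i in range(n)]
    err_c = [c[j] - sum(X[i][j] for i in range(n)) for j in range(m)]
    total = sum(err_c)  # equals sum(err_r) when sum(r) == sum(c)
    if total > 0:
        X = [[X[i][j] + err_r[i] * err_c[j] / total for j in range(m)] for i in range(n)]
    return X
```

Because every step only shrinks entries or adds nonnegative mass, the output stays nonnegative and lies exactly in U(r, c) up to floating-point error.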

mdot_tnt.rounding.rounded_cost_altschuler(u, v, r, c, C, gamma)

Performs rounding and cost computation in log-domain given dual variables.

This function computes the transport cost without storing the full n×m transport plan, making it memory efficient.

Parameters:
  • u (Tensor) – Dual variable for rows of shape (n,).

  • v (Tensor) – Dual variable for columns of shape (m,).

  • r (Tensor) – Row marginal of shape (n,).

  • c (Tensor) – Column marginal of shape (m,).

  • C (Tensor) – Cost matrix of shape (n, m).

  • gamma (Union[float, Tensor]) – Temperature (inverse of the entropic regularization weight).

Return type:

Tensor

Returns:

The transport cost of the rounded plan, as a scalar tensor.

lowmem

Low-memory MDOT-TNT solver using column-batched materialization.

Achieves O(nk + (n+m)d) working memory, where k is the block size and d is the point dimension, by:
  1. Never materializing the full n x m cost matrix C — cost blocks C[:, s:e] are computed on-the-fly from point clouds (X, Y) and a cost function.

  2. Never materializing the full n x m transport plan P — plan blocks are computed on-the-fly from dual variables and cost blocks.

The user controls the trade-off via block_size (k):
  • k = m: O(nm) memory, fastest (same as standard algorithm)

  • k = sqrt(m): O(n*sqrt(m)) memory (good trade-off)

  • k = 1: O(n) memory, slowest (like element-wise PyKeOps)
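To put the trade-off in numbers, the sketch below computes the size of a single n × k block in float64 (8 bytes per entry) for n = m = 100,000. It counts one block only and ignores constant factors from the other buffers the solver keeps, so treat the figures as order-of-magnitude estimates. block_bytes is a hypothetical helper.

```python
import math


def block_bytes(n, k, bytes_per_entry=8):
    """Bytes for one n x k block of float64 values (8 bytes per entry)."""
    return n * k * bytes_per_entry


n = m = 100_000
for label, k in [("k = m", m), ("k = sqrt(m)", math.isqrt(m)), ("k = 1", 1)]:
    print(f"{label:12s} k={k:>7d}  {block_bytes(n, k) / 2**20:12.3f} MiB")
```

At this size a full block (k = m) needs 80 GB, k = 316 needs about 253 MB, and k = 1 needs 800 kB.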

Supports two input modes:
  • Point cloud mode: X (n, d), Y (m, d), cost_fn — truly O(nk + (n+m)d)

  • Dense matrix mode: C (n, m) — O(nm) for C storage, O(nk) working memory

Based on:

“A Truncated Newton Method for Optimal Transport” by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).

mdot_tnt.lowmem.squared_euclidean(X, Y_block)

Squared Euclidean cost: C[i,j] = ||X[i] - Y[j]||^2.

Parameters:
  • X (Tensor) – Source points, shape (n, d).

  • Y_block (Tensor) – Target point block, shape (k, d).

Return type:

Tensor

Returns:

Cost block of shape (n, k).
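Any cost_fn passed to solve_OT_lowmem must follow the same (X, Y_block) -> (n, k) contract. The pure-Python sketch below uses nested lists in place of tensors and shows two equivalent ways to compute a squared Euclidean block: coordinate by coordinate, and via the expansion ||x||^2 + ||y||^2 - 2<x, y> (clamped at 0 to absorb round-off). Whether squared_euclidean uses the expanded form is not stated; all names here are hypothetical.

```python
import math


def squared_euclidean_direct(X, Y_block):
    """C[i][j] = ||X[i] - Y_block[j]||^2, computed coordinate by coordinate."""
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in Y_block] for x in X]


def squared_euclidean_expanded(X, Y_block):
    """Same block via ||x||^2 + ||y||^2 - 2<x, y>, clamped at 0 against round-off."""
    x_sq = [sum(a * a for a in x) for x in X]
    y_sq = [sum(b * b for b in y) for y in Y_block]
    return [[max(x_sq[i] + y_sq[j] - 2.0 * sum(a * b for a, b in zip(x, y)), 0.0)
             for j, y in enumerate(Y_block)]
            for i, x in enumerate(X)]


def euclidean_block(X, Y_block):
    """C[i][j] = ||X[i] - Y_block[j]||, the square root of the squared block."""
    return [[math.sqrt(v) for v in row] for row in squared_euclidean_direct(X, Y_block)]
```

For n = 2 source points and a block of k = 1 target point, each function returns a 2 × 1 block.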

mdot_tnt.lowmem.euclidean(X, Y_block)

Euclidean cost: C[i,j] = ||X[i] - Y[j]||.

Parameters:
  • X (Tensor) – Source points, shape (n, d).

  • Y_block (Tensor) – Target point block, shape (k, d).

Return type:

Tensor

Returns:

Cost block of shape (n, k).

class mdot_tnt.lowmem.LowMemoryTruncatedNewtonProjector(device, dtype, block_size, cost_block_fn, **kwargs)

Low-memory variant of the Truncated Newton projector.

Instead of storing the full cost matrix or transport plan, cost blocks are computed on-the-fly via a user-supplied cost_block_fn. All log-sum-exp (LSE) reductions, Hessian-vector products, and diagonal preconditioner computations are performed block-by-block with O(nk) working memory, where k = block_size.

Parameters:
  • device (device) – PyTorch device for computations.

  • dtype (dtype) – Data type for tensors.

  • block_size (int) – Number of columns to process at once.

  • cost_block_fn (Callable[[int, int], Tensor]) – (start, end) -> C[:, start:end] of shape (n, end-start).

  • **kwargs (Any) – Additional options (debug: bool for verbose output).
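In point-cloud mode, a cost_block_fn with the (start, end) -> C[:, start:end] contract can be built as a closure over X, Y and a cost function, so that only the k target points of the current block are touched. The sketch below uses plain lists in place of tensors; make_cost_block_fn, column_blocks and sq_euclid are hypothetical helpers, not part of mdot_tnt.

```python
def sq_euclid(X, Y_block):
    """Squared Euclidean cost block of shape (len(X), len(Y_block))."""
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in Y_block] for x in X]


def make_cost_block_fn(X, Y, cost_fn):
    """Hypothetical helper: build (start, end) -> C[:, start:end] from point clouds."""
    def cost_block_fn(start, end):
        # Only the end - start target points of this block are used.
        return cost_fn(X, Y[start:end])
    return cost_block_fn


def column_blocks(m, block_size):
    """Yield (start, end) column ranges covering 0..m in chunks of block_size."""
    for start in range(0, m, block_size):
        yield start, min(start + block_size, m)
```

With m = 3 columns and block_size = 2, column_blocks yields (0, 2) and then the shorter final block (2, 3).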

project(gamma, log_r, log_c, eps_d, u, v)

Project onto the set of couplings satisfying marginal constraints.

Parameters:
  • gamma (Union[float, Tensor]) – Temperature (inverse regularization weight) for this projection step.

  • log_r (Tensor) – Log of row marginals, shape (n,).

  • log_c (Tensor) – Log of column marginals, shape (m,).

  • eps_d (Union[float, Tensor]) – Convergence tolerance for the dual gradient norm.

  • u (Tensor) – Initial row dual variables, shape (n,).

  • v (Tensor) – Initial column dual variables, shape (m,).

Returns:

A tuple containing:

  • u: Updated row dual variables.

  • v: Updated column dual variables.

  • logs: Dictionary with optimization statistics.

  • success: Whether projection converged successfully.

chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_chi, maxOps=inf)

Chi-squared Sinkhorn iterations using blocked LSE.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, int]

newton_solve(u, v, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500)

Newton solve using blocked Hessian-vector products.

Transport plan blocks are recomputed on-the-fly for each mat-vec product from the (unchanged) dual variables u, v and the cost function.

Return type:

Tuple[Tensor, Tensor, int, Tensor, bool]

mdot_tnt.lowmem.blocked_rounded_cost(u, v, r, c, cost_block_fn, m, gamma, block_size)

Compute the rounded transport cost without materializing C or P.

Mirrors rounded_cost_altschuler but accesses the cost matrix only through cost_block_fn. Working memory is O(n * block_size).

Parameters:
  • u (Tensor) – Dual variable for rows, shape (n,).

  • v (Tensor) – Dual variable for columns, shape (m,).

  • r (Tensor) – Row marginal, shape (n,).

  • c (Tensor) – Column marginal, shape (m,).

  • cost_block_fn (Callable[[int, int], Tensor]) – (s, e) -> C[:, s:e].

  • m (int) – Number of columns.

  • gamma (Union[float, Tensor]) – Temperature.

  • block_size (int) – Number of columns per block.

Return type:

Tensor

Returns:

Transport cost (scalar tensor).

mdot_tnt.lowmem.blocked_transport_cost(u, v, cost_block_fn, m, gamma, block_size)

Compute unrounded transport cost without materializing C or P.

cost = sum_{i,j} exp(u_i + v_j - gamma * C_{ij}) * C_{ij}

Parameters:
  • u (Tensor) – Dual variable for rows, shape (n,).

  • v (Tensor) – Dual variable for columns, shape (m,).

  • cost_block_fn (Callable[[int, int], Tensor]) – (s, e) -> C[:, s:e].

  • m (int) – Number of columns.

  • gamma (Union[float, Tensor]) – Temperature.

  • block_size (int) – Number of columns per block.

Return type:

Tensor

Returns:

Transport cost (scalar tensor).
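The formula above can be evaluated one column block at a time, so that only an n × block_size slice of C (and of P) exists at any moment. The pure-Python sketch below does this with nested lists; blocked_transport_cost_sketch is a hypothetical name, and its result matches the dense double sum for any block size.

```python
import math


def blocked_transport_cost_sketch(u, v, cost_block_fn, m, gamma, block_size):
    """Sum exp(u_i + v_j - gamma * C_ij) * C_ij one column block at a time.

    cost_block_fn(s, e) returns C[:, s:e] as a list of n rows of length e - s.
    """
    total = 0.0
    for s in range(0, m, block_size):
        e = min(s + block_size, m)
        C_blk = cost_block_fn(s, e)  # only this n x (e - s) block is materialized
        for i, row in enumerate(C_blk):
            for jj, cij in enumerate(row):
                # jj is the column index within the block; s + jj is the global index.
                total += math.exp(u[i] + v[s + jj] - gamma * cij) * cij
    return total
```

Setting block_size = m reproduces the dense computation in a single pass; smaller blocks trade more calls to cost_block_fn for less memory.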

mdot_tnt.lowmem.mdot_lowmem(r, c, cost_block_fn, gamma_f, block_size, gamma_i=16, p=1.5, q=2.0, devices=None)

Solve entropic-regularized OT using MDOT with low-memory projection.

Parameters:
  • r (Tensor) – Row marginal, shape (n,).

  • c (Tensor) – Column marginal, shape (m,).

  • cost_block_fn (Callable[[int, int], Tensor]) – (s, e) -> C[:, s:e] — computes cost blocks lazily.

  • gamma_f (float) – Final temperature (inverse regularization weight).

  • block_size (int) – Number of columns per block for the projector.

  • gamma_i (float) – Initial temperature.

  • p (float) – Exponent for the epsilon schedule.

  • q (float) – Temperature annealing factor.

  • devices (Optional[Any]) – Optional list of torch.device objects for multi-GPU execution (from mdot_tnt.multigpu._resolve_devices()). When provided, column blocks are distributed across these devices and blocked primitives run in parallel.

Returns:

A tuple containing:

  • u: Row dual variables, shape (n,).

  • v: Column dual variables, shape (m,).

  • gamma: Final temperature achieved.

  • k_total: Total number of O(n^2) primitive operations.

  • logs: Optimization statistics.
