{ "cells": [ { "cell_type": "markdown", "id": "o6hbis23wg9", "metadata": {}, "source": "# Comparing MDOT-TNT with POT Solvers" }, { "cell_type": "markdown", "id": "1jk6hjlhtv7", "source": "This tutorial demonstrates MDOT-TNT on a large random OT problem (n=14,000) and compares its speed and accuracy against the POT library's exact solver (EMD), Sinkhorn, and Greenkhorn.", "metadata": {} }, { "cell_type": "markdown", "id": "92a1a4fd-004f-4500-8758-171446f2d582", "metadata": { "is_executing": true }, "source": [ "First, install the POT package for comparison. https://pythonot.github.io/" ] }, { "cell_type": "code", "execution_count": 1, "id": "9880838b-c88a-491c-b000-c81d539e4421", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: POT in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (0.9.4)\n", "Requirement already satisfied: numpy>=1.16 in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (from POT) (1.26.4)\n", "Requirement already satisfied: scipy>=1.6 in /h/314/kemertas/anaconda3/envs/mdot_tnt/lib/python3.10/site-packages (from POT) (1.15.1)\n" ] } ], "source": [ "!pip install POT" ] }, { "cell_type": "markdown", "id": "af0df559-f79b-4147-8429-810213ad4fa1", "metadata": {}, "source": [ "Import packages." ] }, { "cell_type": "code", "execution_count": 2, "id": "8013a4e2c8611fc1", "metadata": { "collapsed": false, "is_executing": true, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "import gc\n", "import time\n", "\n", "import ot\n", "import torch as th\n", "\n", "from mdot_tnt import solve_OT\n", "from mdot_tnt.rounding import round_altschuler\n", "\n", "device = \"cuda:0\"" ] }, { "cell_type": "markdown", "id": "d21a5ac2-bf92-4b46-a76d-282d1ed14cef", "metadata": { "collapsed": false, "is_executing": true, "jupyter": { "outputs_hidden": false } }, "source": [ "Add a function for sampling random OT problems." 
] }, { "cell_type": "code", "execution_count": 3, "id": "edcb3503-b436-4fc7-8ee4-db5270238df3", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def sample_random_problem(N, M, dim=100):\n", "    # Sample the marginals r and c from a flat Dirichlet distribution.\n", "    r = th.distributions.Dirichlet(th.ones(N)).sample()\n", "    c = th.distributions.Dirichlet(th.ones(M)).sample()\n", "\n", "    # Sample N points x and M points y from a standard normal distribution in `dim` dimensions.\n", "    x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))\n", "    y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))\n", "\n", "    # Compute the squared Euclidean cost matrix C_ij = ||x_i - y_j||_2^2, scaled so that max(C) = 1.\n", "    C = th.cdist(x, y, p=2) ** 2\n", "    C /= C.max()\n", "\n", "    # Renormalize the marginals to guard against rounding error.\n", "    # (Conversion to double precision is left to the caller.)\n", "    r /= r.sum()\n", "    c /= c.sum()\n", "\n", "    return r, c, C" ] }, { "cell_type": "markdown", "id": "55b72848-fa9e-4533-8280-e1e41df4f0be", "metadata": {}, "source": [ "Sample a large problem (n=14000)." ] }, { "cell_type": "code", "execution_count": 6, "id": "3effc199-180a-4c13-8c66-4d7b295066d1", "metadata": {}, "outputs": [], "source": [ "n = 14000\n", "r, c, C = sample_random_problem(n, n)\n", "r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)" ] }, { "cell_type": "markdown", "id": "7d18fa77-80b8-4409-ac3a-9be5a321def0", "metadata": {}, "source": [ "Let's use the highly efficient exact solver (CPU-based) of the POT library and time it." 
] }, { "cell_type": "code", "execution_count": 7, "id": "a43f565d-750e-4b7b-b660-f0a540466527", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OT Cost: 0.3182945673, Time: 97.668\n" ] } ], "source": [ "r_, c_, C_ = r.cpu().numpy(), c.cpu().numpy(), C.cpu().numpy()\n", "time_start = time.time()\n", "cost_emd = ot.emd2(r_, c_, C_, numItermax=int(1e10))\n", "elapsed = time.time() - time_start\n", "print(f\"OT Cost: {cost_emd:.10f}, Time: {elapsed:.3f}\")" ] }, { "cell_type": "markdown", "id": "3032e695-9b75-4b32-a7d2-1a93c33525ed", "metadata": {}, "source": [ "Now use mdot_tnt to tackle the problem on a GPU (NVIDIA RTX 2080 Ti in this case)." ] }, { "cell_type": "code", "execution_count": 8, "id": "c172ae99-d6a2-48e9-9cef-02deeece4779", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MDOT-TNT error: 8.803e-05, Time: 4.435\n" ] } ], "source": [ "time_start = time.time()\n", "cost = solve_OT(\n", " r, c, C, gamma_f=1000\n", ") # gamma_f is the inverse of the final regularization weight (1e-3 here)\n", "elapsed = time.time() - time_start\n", "print(f\"MDOT-TNT error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}\")\n", "gc.collect()\n", "th.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "a2c09152-708a-4d1b-832d-c6fe89a847d1", "metadata": { "scrolled": true }, "source": [ "4-5 decimal precision with more than 20x speedup. Needless to say, the speedup can be better on higher-end GPUs.\n", "Let's also check the speedup using FP32 precision." 
] }, { "cell_type": "code", "execution_count": 9, "id": "8e17d26d-e80f-4c60-ace8-8a1cf1bfcb77", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MDOT-TNT error: 8.821e-05, Time: 1.705\n" ] } ], "source": [ "time_start = time.time()\n", "cost = solve_OT(\n", " r.float(), c.float(), C.float(), gamma_f=1000\n", ") # gamma_f is the inverse of the final regularization weight (1e-3 here)\n", "elapsed = time.time() - time_start\n", "print(f\"MDOT-TNT error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}\")\n", "gc.collect()\n", "th.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "bd6e0927-bbb2-4578-8149-577c2904024c", "metadata": { "scrolled": true }, "source": [ "57x speedup on this random problem! Not bad!" ] }, { "cell_type": "markdown", "id": "cea5928e-2e39-49e3-924c-497c4f1b4333", "metadata": { "scrolled": true }, "source": [ "If either marginal is known to have many tiny entries (is effectively a sparse vector), we can further accelerate computation by dropping those particles by setting `drop_tiny=True`. Note that this feature was not used in the paper for fairness in benchmarking, but can be useful in practice." 
] }, { "cell_type": "code", "execution_count": 10, "id": "7d805ae1-e256-4afc-aa81-c8b7b366f8d1", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Set a random half of the entries of r and c to 1e-20, and renormalize.\n", "r2 = r.clone()\n", "c2 = c.clone()\n", "r2[th.randperm(n)[: n // 2]] = 1e-20\n", "c2[th.randperm(n)[: n // 2]] = 1e-20\n", "r2 /= r2.sum()\n", "c2 /= c2.sum()" ] }, { "cell_type": "code", "execution_count": 11, "id": "bd3fc998-3436-4eec-abb4-2645f621642a", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OT Cost: 0.3182945673, Time: 82.562\n" ] } ], "source": [ "time_start = time.time()\n", "cost_emd2 = ot.emd2(r2.cpu().numpy(), c2.cpu().numpy(), C.cpu().numpy(), numItermax=int(1e10))\n", "elapsed = time.time() - time_start\n", "print(f\"OT Cost: {cost_emd2:.10f}, Time: {elapsed:.3f}\")" ] }, { "cell_type": "markdown", "id": "fd2f1075-1760-441f-8c6e-928c7b853066", "metadata": { "scrolled": true }, "source": [ "The exact solver's runtime is similar to before. Let's rerun MDOT-TNT with `drop_tiny=True`." ] }, { "cell_type": "code", "execution_count": 12, "id": "c061f3d4-0a2e-4b43-bf52-19fa971c6e1b", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 7028 entries from r and 7032 entries from c.\n", "MDOT-TNT error: 8.172e-05, Time: 1.155\n" ] } ], "source": [ "time_start = time.time()\n", "cost = solve_OT(\n", " r2, c2, C, gamma_f=1000, drop_tiny=True\n", ") # gamma_f is the inverse of the final regularization weight (1e-3 here)\n", "elapsed = time.time() - time_start\n", "print(f\"MDOT-TNT error: {cost - cost_emd2:.3e}, Time: {elapsed:.3f}\")\n", "gc.collect()\n", "th.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "6efc1bce-8f4b-4b20-83d3-c22b0a02321b", "metadata": { "scrolled": true }, "source": [ "Same level of precision as before, but this time with a ~70x speedup. Now let's do the same with FP32 precision." 
] }, { "cell_type": "code", "execution_count": 13, "id": "d212ae9e-fa47-4ddc-9685-d22e33025e7a", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dropped 7028 entries from r and 7032 entries from c.\n", "MDOT-TNT error: 8.187e-05, Time: 0.535\n" ] } ], "source": [ "time_start = time.time()\n", "cost = solve_OT(\n", " r2.float(), c2.float(), C.float(), gamma_f=1000, drop_tiny=True\n", ") # gamma_f is the inverse of the final regularization weight (1e-3 here)\n", "elapsed = time.time() - time_start\n", "print(f\"MDOT-TNT error: {cost - cost_emd2:.3e}, Time: {elapsed:.3f}\")\n", "gc.collect()\n", "th.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "dd7a2b17-62bc-4703-bd75-35c600fcc22a", "metadata": { "scrolled": true }, "source": [ "That is a 154x speedup. Let's go back to the original problem (dense marginals) and see how Sinkhorn fares, starting with strong regularization (reg=1e-2) and then decreasing the regularization weight." ] }, { "cell_type": "code", "execution_count": 14, "id": "cc92ef38-2ff5-4161-b692-4a9cb97a140e", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sinkhorn error: 1.458e-02, Time: 0.511\n" ] } ], "source": [ "gc.collect()\n", "th.cuda.empty_cache()\n", "time_start = time.time()\n", "plan = ot.sinkhorn(r, c, C, reg=1 / 100, numItermax=int(1e10))\n", "plan = round_altschuler(plan, r, c)\n", "cost = (plan * C).sum()\n", "elapsed = time.time() - time_start\n", "print(f\"Sinkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}\")\n", "del plan" ] }, { "cell_type": "markdown", "id": "541df3c4-72a2-4976-a437-ecfb01ff45f0", "metadata": { "scrolled": true }, "source": [ "Recall that the optimal cost is about 0.318, so the relative error here is about 0.0146 / 0.318 = 4.6%, which is hardly negligible. Let's run Sinkhorn at the same temperature as MDOT-TNT (regularization weight 1e-3)." 
] }, { "cell_type": "code", "execution_count": 15, "id": "a2fa6905-a32d-4388-b8f7-e9c70f95697d", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sinkhorn error: 7.315e-05, Time: 67.102\n" ] } ], "source": [ "time_start = time.time()\n", "plan = ot.sinkhorn(r, c, C, reg=1 / 1000, numItermax=int(1e10))\n", "plan = round_altschuler(plan, r, c)\n", "cost = (plan * C).sum()\n", "elapsed = time.time() - time_start\n", "print(f\"Sinkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}\")\n", "del plan" ] }, { "cell_type": "markdown", "id": "98dd8cde-b3b5-4df4-8222-a2337d934079", "metadata": { "scrolled": true }, "source": [ "MDOT-TNT exhibits 15x speedup (took 4.435 seconds under the same setup of dense vectors + FP64 precision). As we show in the paper, the gap grows with weaker regularization." ] }, { "cell_type": "markdown", "id": "199549be-21a8-4d1f-add2-5db753a22c96", "metadata": {}, "source": [ "Let's also give Greenkhorn by Altschuler et al. (2017) a try." ] }, { "cell_type": "code", "execution_count": 17, "id": "b591e6b4-eefa-49b5-ad36-88ebc1798e5f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Greenkhorn error: 7.461e-05, Time: 2929.723\n" ] } ], "source": [ "gc.collect()\n", "th.cuda.empty_cache()\n", "time_start = time.time()\n", "plan = ot.bregman.greenkhorn(r, c, C, reg=1 / 1000, numItermax=int(1e10))\n", "plan = round_altschuler(plan, r, c)\n", "cost = (plan * C).sum()\n", "elapsed = time.time() - time_start\n", "print(f\"Greenkhorn error: {cost - cost_emd:.3e}, Time: {elapsed:.3f}\")\n", "del plan" ] }, { "cell_type": "markdown", "id": "9688fc58-c071-480e-ad11-d74c806ed5cc", "metadata": {}, "source": [ "For this value of n=14000, Greenkhorn suffers from low GPU utilization. 
Even if the total number of row or column updates is smaller than that of Sinkhorn, in practice it is substantially slower because it updates only one row or column at a time, which limits parallelization." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }