Skip to content

Instantly share code, notes, and snippets.

@hvaara
Created May 27, 2025 17:01
Show Gist options
  • Select an option

  • Save hvaara/34afb8fb3fc1422c319a5fd972f8fc3a to your computer and use it in GitHub Desktop.

Select an option

Save hvaara/34afb8fb3fc1422c319a5fd972f8fc3a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m2024-08-17T00:13:28.454221Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mtorch.__version__ = '2.5.0a0+git648fc6c'\u001b[0m\n"
]
}
],
"source": [
"import time\n",
"import torch\n",
"import logging\n",
"import datetime\n",
"import structlog\n",
"import os\n",
"\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# Bit flags for PYTORCH_MPS_LOG_PROFILE_INFO (see the commented-out os.environ line below).\n",
"OPERATION_INFO = 1 << 0\n",
"COPY_INFO = 1 << 1\n",
"CPU_FALLBACK_INFO = 1 << 2\n",
"\n",
"ALL_STATS = 1 << 3\n",
"OPERATION_STATS = 1 << 4\n",
"COPY_STATS = 1 << 5\n",
"CPU_FALLBACK_STATS = 1 << 6\n",
"\n",
"INCLUDE_GPU_TIME = 1 << 7\n",
"INCLUDE_KERNEL_TIME = 1 << 8\n",
"INCLUDE_BUFFER_ID = 1 << 9\n",
"\n",
"# Mask with every bit defined above set, i.e. (1 << 10) - 1.\n",
"LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1\n",
"\n",
"# All *_INFO and INCLUDE_* flags, none of the *_STATS flags.\n",
"# The flags are distinct bits, so bitwise OR is the idiomatic way to combine them.\n",
"ALL_FLAGS = (\n",
"    OPERATION_INFO | COPY_INFO | CPU_FALLBACK_INFO\n",
"    | INCLUDE_GPU_TIME | INCLUDE_KERNEL_TIME | INCLUDE_BUFFER_ID\n",
")\n",
"\n",
"# Uncomment to enable MPS profiler logging and verbose MPS allocator logging.\n",
"# os.environ[\"PYTORCH_MPS_LOG_PROFILE_INFO\"] = str(ALL_FLAGS)\n",
"# os.environ[\"PYTORCH_DEBUG_MPS_ALLOCATOR\"] = \"1\"\n",
"\n",
"DEVICE = \"mps\"\n",
"\n",
"# Return codes of benchmark().\n",
"MEMLEAK_DETECTED = 1\n",
"NO_MEMLEAK_DETECTED = 2\n",
"\n",
"# Console logging with ISO-8601 UTC timestamps; logging.NOTSET lets every level through.\n",
"structlog.configure(\n",
"    processors=[\n",
"        structlog.contextvars.merge_contextvars,\n",
"        structlog.processors.add_log_level,\n",
"        structlog.processors.StackInfoRenderer(),\n",
"        structlog.dev.set_exc_info,\n",
"        structlog.processors.TimeStamper(fmt=\"iso\", utc=True),\n",
"        structlog.dev.ConsoleRenderer(),\n",
"    ],\n",
"    wrapper_class=structlog.make_filtering_bound_logger(logging.NOTSET),\n",
"    context_class=dict,\n",
"    logger_factory=structlog.PrintLoggerFactory(),\n",
"    cache_logger_on_first_use=False,\n",
")\n",
"logger = structlog.get_logger()\n",
"\n",
"logger.info(f\"{torch.__version__ = }\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def print_allocated_memory(i=None, iters=None):\n",
"    \"\"\"Log the current and driver-level MPS memory usage.\n",
"\n",
"    When both `i` and `iters` are given, the message is tagged with the\n",
"    loop progress, e.g. \"Memory info (100/1000)\".\n",
"    \"\"\"\n",
"    progress = \"\" if i is None or iters is None else f\" ({i}/{iters})\"\n",
"    logger.info(\n",
"        \"Memory info\" + progress,\n",
"        current_allocated_memory=torch.mps.current_allocated_memory(),\n",
"        driver_allocated_memory=torch.mps.driver_allocated_memory(),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def empty_cache(settle_seconds=5):\n",
"    \"\"\"Release cached MPS memory, logging memory usage before and after.\n",
"\n",
"    Args:\n",
"        settle_seconds: seconds to sleep after `torch.mps.empty_cache()` before\n",
"            taking the second memory snapshot (presumably to give the driver\n",
"            time to actually return the memory). Defaults to 5.\n",
"    \"\"\"\n",
"    # MPS work is queued asynchronously; wait for it so the \"before\" snapshot\n",
"    # and empty_cache() see a quiescent device.\n",
"    torch.mps.synchronize()\n",
"    print_allocated_memory()\n",
"    torch.mps.empty_cache()\n",
"    time.sleep(settle_seconds)\n",
"    logger.info(\"MPS cache cleared.\")\n",
"    print_allocated_memory()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def benchmark(\n",
"    model,\n",
"    input,\n",
"    should_backward=False,\n",
"    should_print_model=True,\n",
"    should_empty_cache=True,\n",
"    debug_n_iters=100,\n",
"    memory_threshold=50 * 1024**3,\n",
"    iters=1000,\n",
"    is_torch_model=True,\n",
"    should_synchronize=True):\n",
"    \"\"\"Repeatedly run `model(input)` on DEVICE, timing each step and watching for leaks.\n",
"\n",
"    Args:\n",
"        model: callable to benchmark; moved to DEVICE when `is_torch_model` is True.\n",
"        input: argument passed to `model`; moved to DEVICE when `is_torch_model` is True.\n",
"        should_backward: also run `output.sum().backward()` every iteration.\n",
"        should_print_model: print `model` before benchmarking.\n",
"        should_empty_cache: call `empty_cache()` before entering the loop.\n",
"        debug_n_iters: log memory usage on the first iteration and every\n",
"            `debug_n_iters` iterations; 0 or less logs only the first one.\n",
"        memory_threshold: stop early once `torch.mps.driver_allocated_memory()`\n",
"            exceeds this many bytes.\n",
"        iters: maximum number of iterations.\n",
"        is_torch_model: whether `model` and `input` are moved with `.to(DEVICE)`.\n",
"        should_synchronize: call `torch.mps.synchronize()` before starting and\n",
"            stopping the clock. MPS kernels run asynchronously, so without it the\n",
"            timings only measure how long it takes to enqueue the work.\n",
"\n",
"    Returns:\n",
"        MEMLEAK_DETECTED if `memory_threshold` was exceeded, else NO_MEMLEAK_DETECTED.\n",
"    \"\"\"\n",
"    ret = NO_MEMLEAK_DETECTED\n",
"    timings = []\n",
"\n",
"    if is_torch_model:\n",
"        model = model.to(DEVICE)\n",
"        input = input.to(DEVICE)\n",
"\n",
"    if should_print_model:\n",
"        print(model)\n",
"\n",
"    if should_empty_cache:\n",
"        empty_cache()\n",
"\n",
"    logger.info(\"Entering benchmark loop.\")\n",
"    if should_synchronize:\n",
"        torch.mps.synchronize()  # don't bill earlier queued work to the first iteration\n",
"    i = 0  # last iteration reached; stays 0 when iters < 1 so the final log below can't NameError\n",
"    for i in range(1, iters + 1):\n",
"        start_time = time.perf_counter()\n",
"        output = model(input)\n",
"        if should_backward:\n",
"            loss = output.sum()\n",
"            loss.backward()\n",
"        if should_synchronize:\n",
"            torch.mps.synchronize()  # wait for the GPU work before stopping the clock\n",
"        timings.append(time.perf_counter() - start_time)\n",
"\n",
"        if torch.mps.driver_allocated_memory() > memory_threshold:\n",
"            logger.warning(f\"torch.mps.driver_allocated_memory() > {memory_threshold/1024**3} GiB threshold reached.\")\n",
"            ret = MEMLEAK_DETECTED\n",
"            break\n",
"        is_debug_iter = i == 1 or (debug_n_iters > 0 and i % debug_n_iters == 0)\n",
"        # The last iteration is always logged after the loop, so skip it here.\n",
"        if is_debug_iter and i != iters:\n",
"            print_allocated_memory(i, iters)\n",
"    print_allocated_memory(i, iters)\n",
"    logger.info(\"Exited benchmark loop.\")\n",
"    timings = torch.Tensor(timings)\n",
"    logger.info(\"Timings\", sum=torch.sum(timings).item(), mean=torch.mean(timings).item(), std=torch.std(timings).item())\n",
"    return ret"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Our familiar example: with C = 32, driver memory grows until the threshold is hit"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=True)\n",
")\n",
"\u001b[2m2024-08-17T00:14:09.160201Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537134080\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m547061760\u001b[0m\n",
"\u001b[2m2024-08-17T00:14:14.165730Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-17T00:14:14.167033Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537134080\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m547045376\u001b[0m\n",
"\u001b[2m2024-08-17T00:14:14.168815Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-17T00:14:14.239500Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m1074004992\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1639694336\u001b[0m\n",
"\u001b[2m2024-08-17T00:14:15.197293Z\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mtorch.mps.driver_allocated_memory() > 50.0 GiB threshold reached.\u001b[0m\n",
"\u001b[2m2024-08-17T00:14:15.197684Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (97/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m1074004992\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m53716172800\u001b[0m\n",
"\u001b[2m2024-08-17T00:14:15.197890Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-17T00:14:15.198880Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.010581390000879765\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.01245816983282566\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m1.0263948440551758\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 32, 256, 256\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Reduce the channels and the issue goes away"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:17.865508Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537134080\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m80559718400\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:22.888176Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:22.889216Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537134080\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m79485976576\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:22.889656Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:23.943998Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:23.949310Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:23.954487Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.043924Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.105883Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.167636Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.229711Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.292149Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.354226Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.415980Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.477972Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m570688512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.478600Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:24.479454Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.0015816929517313838\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.0333653949201107\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m1.5816929340362549\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 2, 256, 256\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Increase the channels by 1 and we start seeing leaked bytes\n",
"\n",
"This section contains multiple examples, each increasing the channel count by 1."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:24.636505Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m50594816\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1654407168\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.650194Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.651909Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m50594816\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1117536256\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.654330Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.668110Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1153220608\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.708288Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m6136053760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.738447Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m11169218560\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:29.890037Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m16202383360\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.047549Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m21235548160\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.205191Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m26268712960\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.366415Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m31301877760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.524351Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m36335042560\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.681660Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m41368207360\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:30.841429Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m46401372160\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:31.031374Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m100926464\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m51434536960\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:31.031864Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:31.032647Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.0013680877164006233\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.0017444362165406346\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m1.3680877685546875\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 3, 256, 256\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:31.376249Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m67372032\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m51434536960\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:36.381718Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:36.382303Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m67372032\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m51434536960\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:36.382678Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.030970Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1170030592\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.062819Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m7813808128\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.222218Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m14524694528\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.431800Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m21235580928\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.640647Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m27946467328\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:37.849484Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m34657353728\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.058709Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m41368240128\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.286461Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m48079126528\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.524176Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m54790012928\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.749961Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m61500899328\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.974781Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m134480896\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m68211785728\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.975512Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:38.976199Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.0025823896285146475\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.02049700915813446\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m2.5823895931243896\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 4, 256, 256\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
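"Eventually we hit the leak threshold, as in the first example. (These outputs were recorded with a 75 GiB threshold; the current `memory_threshold` default of `benchmark` is 50 GiB.)"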
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=256, out_features=256, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:39.280440Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m84149248\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m68211785728\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:44.281594Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:44.284718Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m84149248\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m68211785728\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:44.285311Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:45.156796Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1186840576\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:45.194586Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m9491562496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:45.423609Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m17880170496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:45.744337Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m26268778496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:46.070889Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m34657386496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:46.412304Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m43045994496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:46.779047Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m51434602496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:47.138406Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m59823210496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:47.483879Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m68211818496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:47.829556Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m76600426496\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:48.007905Z\u001b[0m [\u001b[33m\u001b[1mwarning \u001b[0m] \u001b[1mtorch.mps.driver_allocated_memory() > 75.0 GiB threshold reached.\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:48.008353Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (947/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m168035328\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m80543072256\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:48.008606Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:48.009140Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.003923389129340649\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.028357602655887604\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m3.715449333190918\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 5, 256, 256\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### A different input and layer size\n",
"\n",
"The first example (C = 1, H = W = 512) produces no leaks."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=512, out_features=512, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:48.495006Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m85196800\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m80543072256\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:53.496353Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:53.498689Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m85196800\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m80543072256\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:53.499654Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:54.507564Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:54.512561Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:54.691710Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:54.895429Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:55.098406Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:55.301464Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:55.504659Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:55.707588Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:55.910826Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:56.113348Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:56.316739Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m152305664\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:28:56.317301Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:28:56.317891Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.002810516394674778\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.031870286911726\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m2.810516357421875\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 1, 512, 512\n",
"# nn.Linear(H, W) is applied over the last input dimension (size W), so this relies on H == W.\n",
"model = torch.nn.Sequential(torch.nn.Linear(H, W))\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)  # redundant with benchmark()'s own .to(DEVICE); kept for symmetry with inputs\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
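"Increasing H and W by 1 relative to the example above (from 512 to 513) makes the leak reappear: `driver_allocated_memory` grows on every iteration while `current_allocated_memory` stays flat."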
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=513, out_features=513, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:28:56.576166Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m68686336\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.577181Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.578925Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m68686336\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1151221760\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.579988Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.597519Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m1191100416\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.630834Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m8042496000\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:01.882298Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m14963097600\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:02.205055Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m21883699200\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:02.527381Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m28804300800\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:02.849864Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m35724902400\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:03.172567Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m42645504000\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:03.494945Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m49566105600\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:03.817606Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m56486707200\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:04.140444Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m63407308800\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:04.462729Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m136057600\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m70327910400\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:04.463234Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:29:04.463942Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.002869668649509549\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.0027064078021794558\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m2.86966872215271\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 64, 1, 513, 513\n",
"\n",
"model = torch.nn.Sequential(\n",
" torch.nn.Linear(H, W),\n",
")\n",
"\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Different topology again\n",
"\n",
"With this topology (a single sample, `N = 1`, with `H = W = 8192`), increasing H and W does not produce leaks."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=8192, out_features=8192, bias=True)\n",
")\n",
"\u001b[2m2024-08-14T00:29:04.894810Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537133056\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m70327910400\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:09.900306Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMPS cache cleared. \u001b[0m\n",
"\u001b[2m2024-08-14T00:29:09.901889Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m537133056\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m70327910400\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:09.903028Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mEntering benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:29:09.920765Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m70013337600\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:13.998453Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (100/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:25.232645Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (200/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:38.000769Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (300/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:29:51.836736Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (400/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:30:05.669416Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (500/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:30:19.253773Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (600/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:30:32.933473Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (700/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:30:46.595723Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (800/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:31:00.887105Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (900/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:31:16.919174Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mMemory info (1000/1000) \u001b[0m \u001b[36mcurrent_allocated_memory\u001b[0m=\u001b[35m805568512\u001b[0m \u001b[36mdriver_allocated_memory\u001b[0m=\u001b[35m2157887488\u001b[0m\n",
"\u001b[2m2024-08-14T00:31:16.919621Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mExited benchmark loop. \u001b[0m\n",
"\u001b[2m2024-08-14T00:31:16.920226Z\u001b[0m [\u001b[32m\u001b[1minfo \u001b[0m] \u001b[1mTimings \u001b[0m \u001b[36mmean\u001b[0m=\u001b[35m0.12700077891349792\u001b[0m \u001b[36mstd\u001b[0m=\u001b[35m0.035352881997823715\u001b[0m \u001b[36msum\u001b[0m=\u001b[35m127.00077819824219\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N, C, H, W = 1, 1, 8192, 8192\n",
"\n",
"model = torch.nn.Sequential(\n",
" torch.nn.Linear(H, W),\n",
")\n",
"\n",
"inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"model.to(DEVICE)\n",
"\n",
"benchmark(model, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
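"Even at this scale (a 256 MiB input tensor plus a 256 MiB weight matrix, with roughly 2 GiB allocated by the driver) we don't see any leaks this time: `driver_allocated_memory` stays constant at 2157887488 bytes from iteration 100 onward. The failure mode does not seem to be related to tensor size alone; rather, it appears to depend on the combination of tensor shape and size."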
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# N, C, H, W = 1, 1, 2 * 8192, 2 * 8192\n",
"\n",
"# model = torch.nn.Sequential(\n",
"# torch.nn.Linear(H, W),\n",
"# )\n",
"\n",
"# inputs = torch.rand(N, C, H, W).to(DEVICE)\n",
"# model.to(DEVICE)\n",
"\n",
"# benchmark(model, inputs, iters=500)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorchdev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment