Last active: May 24, 2024 13:51
test_triton_mm.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f7e69d06-de3c-487c-ad62-7aebce775e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "04d16c8e-bfba-4e6b-9dd9-58daae15135e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm.model_executor.layers.quantization.triton_mm import triton_mixed_mm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b349f02c-7df3-4942-861e-523f00e34436",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HQQMatmulNoCacheMul (assumed importable from hqq.core.quantize) is needed by the patched backprop forward below\n",
    "from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer, HQQMatmulNoCacheMul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7335bab7-6909-45cf-b623-6468052940c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pack_2xint4(t):\n",
    "    \"\"\"\n",
    "    The packing format is such that pairs of consecutive rows are packed into the lower / upper 4 bits of one byte.\n",
    "    E.g.,\n",
    "    Original, unpacked B (dtype i8):\n",
    "    [\n",
    "        [0, 1, 2, 3]\n",
    "        [4, 5, 6, 7]\n",
    "        [8, 9, 10, 11]\n",
    "        [12, 13, 14, 15]\n",
    "    ]\n",
    "    Packed B:\n",
    "    [\n",
    "        [0|4, 1|5, 2|6, 3|7]\n",
    "        [8|12, 9|13, 10|14, 11|15]\n",
    "    ]\n",
    "    (Note each entry in `Packed B` is shown lsb->msb)\n",
    "    \"\"\"\n",
    "    assert t.dtype == torch.int8 or t.dtype == torch.uint8\n",
    "    # Interleave rows 2i / 2i+1, then pack row 2i into the low nibble and row 2i+1 into the high nibble\n",
    "    t = t.reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2)\n",
    "    return (t[0] & 0xF) | (t[1] << 4)"
   ]
  },
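  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pack-2xint4-roundtrip-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Not part of the original notebook: a minimal round-trip sanity check for pack_2xint4,\n",
    "# matching the layout described in its docstring (low nibble = even row, high nibble = odd row).\n",
    "t = torch.arange(16, dtype=torch.uint8).reshape(4, 4)\n",
    "packed = pack_2xint4(t)\n",
    "low, high = packed & 0xF, (packed >> 4) & 0xF\n",
    "unpacked = torch.stack([low, high], dim=1).reshape(t.shape)\n",
    "assert torch.equal(unpacked, t)\n",
    "packed, unpacked"
   ]
  },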
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1ad339f8-d5df-4740-81c8-61f46eba450b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def patch_hqq_to_tritonmm(layer, patch_param):\n",
    "    if isinstance(layer, HQQLinear):\n",
    "\n",
    "        # Handle the no-grouping case\n",
    "        shape = layer.meta['shape']\n",
    "        layer.group_size = layer.quant_config['weight_quant_params']['group_size']\n",
    "        if layer.group_size is None:\n",
    "            layer.group_size = shape[1]\n",
    "\n",
    "        # Reshape scale/zero to the (num_groups, out_features) layout expected by triton_mixed_mm\n",
    "        M, N = shape[::-1]\n",
    "        layer.meta['scale'] = layer.meta['scale'].reshape(N, -1).T\n",
    "        layer.meta['zero'] = layer.meta['zero'].reshape(N, -1).T\n",
    "\n",
    "        # Repack: unpack HQQ's 4-bit weights, transpose to (in_features, out_features), then pack two rows per byte\n",
    "        layer.W_q.data = pack_2xint4(Quantizer.unpack[layer.meta['packing']](layer.W_q).reshape(layer.meta['shape']).T).data\n",
    "\n",
    "        # Kernel settings\n",
    "        layer.fp8_fast_accum = True\n",
    "        layer.kernel_type = \"max_autotune\"\n",
    "\n",
    "        def matmul_tritonmm(self, x, transpose=True):\n",
    "            out_dim = self.meta['shape'][0] if transpose else self.meta['shape'][1]\n",
    "            out = triton_mixed_mm(x.view([-1, x.shape[-1]]),\n",
    "                                  self.W_q,\n",
    "                                  self.meta[\"scale\"],\n",
    "                                  self.meta[\"zero\"],\n",
    "                                  group_size=self.group_size,\n",
    "                                  fp8_fast_accum=self.fp8_fast_accum,\n",
    "                                  kernel_type=self.kernel_type,\n",
    "                                  transposed=not transpose,\n",
    "                                  ).view(*x.shape[:-1], out_dim)  # restore the input's leading dims\n",
    "            return out\n",
    "\n",
    "        def forward_tritonmm_backprop(self, x):\n",
    "            return HQQMatmulNoCacheMul.apply(x, self.matmul, self.bias)\n",
    "\n",
    "        def forward_tritonmm_forward(self, x):\n",
    "            out = self.matmul(x)\n",
    "            if self.bias is not None:\n",
    "                out += self.bias\n",
    "            return out\n",
    "\n",
    "        layer.matmul = lambda x, transpose=True: matmul_tritonmm(layer, x, transpose)\n",
    "        layer.forward = lambda x: forward_tritonmm_backprop(layer, x)\n",
    "\n",
    "    return layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6ff9ca4b-9ddf-40a5-b185-c3ec886f02ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_config = BaseQuantizeConfig(nbits=4,\n",
    "                                  group_size=64,\n",
    "                                  quant_zero=False,\n",
    "                                  quant_scale=False,\n",
    "                                  offload_meta=False,\n",
    "                                  view_as_float=False,\n",
    "                                  axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8fadd71a-7c24-4eed-a8ad-7c60af6284e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_weight = torch.randn(4096, 4096)  # output x input\n",
    "k_weight = torch.randn(1024, 4096)\n",
    "v_weight = torch.randn(1024, 4096)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "4d9529d0-8b1d-4b31-9b53-7422f8243136",
   "metadata": {},
   "outputs": [],
   "source": [
    "dtype = torch.bfloat16\n",
    "triton_params = {}\n",
    "for name, p in [(\"q\", q_weight), (\"k\", k_weight), (\"v\", v_weight)]:\n",
    "    m = torch.nn.Linear(*p.T.shape, bias=False)\n",
    "    m.weight.data.copy_(p)\n",
    "    dummy_hqq_linear = HQQLinear(m, quant_config, compute_dtype=dtype)\n",
    "    patched_hqq_linear = patch_hqq_to_tritonmm(dummy_hqq_linear, None)\n",
    "    triton_params[name] = {\"Wq\": patched_hqq_linear.W_q,\n",
    "                           \"scale\": patched_hqq_linear.meta['scale'],\n",
    "                           \"zero\": patched_hqq_linear.meta['zero']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2f7f93a3-5118-4766-9f24-5e769ff3841e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "qkv_weight = torch.cat([triton_params[k][\"Wq\"] for k in [\"q\", \"k\", \"v\"]], dim=1)\n",
    "qkv_scale = torch.cat([triton_params[k][\"scale\"] for k in [\"q\", \"k\", \"v\"]], dim=1)\n",
    "qkv_zero = torch.cat([triton_params[k][\"zero\"] for k in [\"q\", \"k\", \"v\"]], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e608e1f0-3cac-4dca-b1df-bd669de6e717",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2048, 6144])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qkv_weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c4c025e3-a171-4801-8428-9e588bd516e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "x = torch.randn(16, 4096, device=\"cuda\", dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6dea6917-6644-41c5-af61-5f235fa653aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_qkv = triton_mixed_mm(x,\n",
    "                             qkv_weight,\n",
    "                             qkv_scale,\n",
    "                             qkv_zero,\n",
    "                             group_size=quant_config['weight_quant_params']['group_size'],\n",
    "                             fp8_fast_accum=False,\n",
    "                             kernel_type=\"compute_bound\",\n",
    "                             transposed=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "4bbc933e-a807-41db-992b-4bec442aadd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_q = triton_mixed_mm(x,\n",
    "                           triton_params[\"q\"][\"Wq\"],\n",
    "                           triton_params[\"q\"][\"scale\"],\n",
    "                           triton_params[\"q\"][\"zero\"],\n",
    "                           group_size=quant_config['weight_quant_params']['group_size'],\n",
    "                           fp8_fast_accum=False,\n",
    "                           kernel_type=\"compute_bound\",\n",
    "                           transposed=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "4fbe32cd-c5a4-4ab8-a99c-1d028e4764fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(output_qkv[:, :4096], output_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "ce2bc22b-9961-420f-b67b-d07fe309aa3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_k = triton_mixed_mm(x,\n",
    "                           triton_params[\"k\"][\"Wq\"],\n",
    "                           triton_params[\"k\"][\"scale\"],\n",
    "                           triton_params[\"k\"][\"zero\"],\n",
    "                           group_size=quant_config['weight_quant_params']['group_size'],\n",
    "                           fp8_fast_accum=False,\n",
    "                           kernel_type=\"compute_bound\",\n",
    "                           transposed=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "cf9e24fe-ad1f-4c97-a0eb-7f13482d5b52",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(output_qkv[:, 4096:5120], output_k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "fbf45ba1-5282-4cc5-a96b-849d162f8adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_v = triton_mixed_mm(x,\n",
    "                           triton_params[\"v\"][\"Wq\"],\n",
    "                           triton_params[\"v\"][\"scale\"],\n",
    "                           triton_params[\"v\"][\"zero\"],\n",
    "                           group_size=quant_config['weight_quant_params']['group_size'],\n",
    "                           fp8_fast_accum=False,\n",
    "                           kernel_type=\"compute_bound\",\n",
    "                           transposed=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "c3be901d-10e8-42ba-8392-ee48ba5b8967",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(output_qkv[:, 5120:], output_v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "21dc7908-ce3e-4f60-a9ee-8dbd733a92ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(output_qkv, torch.cat([output_q, output_k, output_v], dim=1))"
   ]
  },
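  {
   "cell_type": "code",
   "execution_count": null,
   "id": "allclose-check-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Not part of the original notebook: exact bitwise equality between the fused QKV matmul and the\n",
    "# three separate matmuls is generally not expected, since the kernels tile and accumulate in a\n",
    "# different order. A tolerance-based comparison is the more informative check; the tolerances\n",
    "# below are illustrative for bf16, not calibrated values.\n",
    "separate = torch.cat([output_q, output_k, output_v], dim=1)\n",
    "print(\"max abs diff:\", (output_qkv - separate).abs().max().item())\n",
    "print(\"allclose:\", torch.allclose(output_qkv, separate, rtol=1e-2, atol=1e-2))"
   ]
  },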
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a50c3def-6633-4ef8-b7f7-2bd8c8701d55",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c01f04a9-9f90-4ad0-9ad0-ec6c74804d05",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "badb7b0f-32c2-48e4-a1fa-27254cd548c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}