Convert a ChatRWKV model to TorchScript.
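The script below loads a ChatRWKV .pth checkpoint, rebuilds the model from TorchScript-friendly modules, scripts it with torch.jit.script, and saves it (optionally passed through torch.jit.optimize_for_inference); an illustrative usage sketch follows the code at the end of the file.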
| """Convert RWKV PyTorch savepoint to TorchScript model. | |
| """ | |
| from typing import Final, List, Optional | |
| import click | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |


class LayerNorm(nn.Module):
    weight: nn.Parameter
    bias: nn.Parameter

    def __init__(
        self,
        weight: nn.Parameter,
        bias: nn.Parameter,
    ):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, x, n_embd: int):
        return F.layer_norm(x, (n_embd,), weight=self.weight, bias=self.bias)


class Attention(nn.Module):
    time_mix_k: nn.Parameter
    time_mix_v: nn.Parameter
    time_mix_r: nn.Parameter
    time_first: nn.Parameter
    time_decay: nn.Parameter
    key: nn.Parameter
    value: nn.Parameter
    receptance: nn.Parameter
    output: nn.Parameter
    float_dtype: Final[torch.dtype]

    def __init__(
        self,
        time_mix_k: nn.Parameter,
        time_mix_v: nn.Parameter,
        time_mix_r: nn.Parameter,
        time_first: nn.Parameter,
        time_decay: nn.Parameter,
        key: nn.Parameter,
        value: nn.Parameter,
        receptance: nn.Parameter,
        output: nn.Parameter,
        float_dtype: torch.dtype,
    ):
        super().__init__()
        self.time_mix_k = time_mix_k
        self.time_mix_v = time_mix_v
        self.time_mix_r = time_mix_r
        self.time_first = time_first
        self.time_decay = time_decay
        self.key = key
        self.value = value
        self.receptance = receptance
        self.output = output
        self.float_dtype = float_dtype

    def SA_one(self, x, state, i: int):
        # Single-token (decode) step of time-mixing: token-shift against the
        # previous token kept in state[5 * i + 1].
        xx = state[5 * i + 1].to(dtype=self.float_dtype)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 1] = x.float()
        r = torch.sigmoid(xr @ self.receptance)
        k = (xk @ self.key).float()
        v = (xv @ self.value).float()
        # Numerically stable WKV recurrence: pp tracks the running maximum
        # exponent so the exp() terms stay in range.
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        ww = self.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + self.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = p
        wkv = (a / b).to(dtype=self.float_dtype)
        return (r * wkv) @ self.output

    def SA_seq(self, x, state, i: int):
        # Sequence (prefill) version: same recurrence, looped over T tokens.
        xx = torch.cat(
            (state[5 * i + 1].to(dtype=self.float_dtype).unsqueeze(0), x[:-1, :])
        )
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 1] = x[-1, :].float()
        r = torch.sigmoid(xr @ self.receptance)
        k = (xk @ self.key).float()
        v = (xv @ self.value).float()
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        T = x.shape[0]
        for t in range(T):
            ww = self.time_first + k[t]
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            a = e1 * aa + e2 * v[t]
            b = e1 * bb + e2
            ww = pp + self.time_decay
            p = torch.maximum(ww, k[t])
            e1 = torch.exp(ww - p)
            e2 = torch.exp(k[t] - p)
            if t != T - 1:
                aa = e1 * aa + e2 * v[t]
                bb = e1 * bb + e2
                pp = p
            else:
                state[5 * i + 2] = e1 * aa + e2 * v[t]
                state[5 * i + 3] = e1 * bb + e2
                state[5 * i + 4] = p
            xx[t] = (a / b).to(dtype=self.float_dtype)
        return (r * xx) @ self.output

    def forward(self, x, state, i: int, seq_mode: bool):
        return self.SA_seq(x, state, i) if seq_mode else self.SA_one(x, state, i)


class FeedForward(nn.Module):
    time_mix_k: nn.Parameter
    time_mix_r: nn.Parameter
    key: nn.Parameter
    value: nn.Parameter
    receptance: nn.Parameter
    float_dtype: Final[torch.dtype]

    def __init__(
        self,
        time_mix_k: nn.Parameter,
        time_mix_r: nn.Parameter,
        key: nn.Parameter,
        value: nn.Parameter,
        receptance: nn.Parameter,
        float_dtype: torch.dtype,
    ):
        super().__init__()
        self.time_mix_k = time_mix_k
        self.time_mix_r = time_mix_r
        self.key = key
        self.value = value
        self.receptance = receptance
        self.float_dtype = float_dtype

    def FF_one(self, x, state, i: int):
        xx = state[5 * i + 0].to(dtype=self.float_dtype)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 0] = x.float()
        r = torch.sigmoid(xr @ self.receptance)
        k = torch.square(torch.relu(xk @ self.key))
        kv = k @ self.value
        return r * kv

    def FF_seq(self, x, state, i: int):
        xx = torch.cat(
            (state[5 * i + 0].to(dtype=self.float_dtype).unsqueeze(0), x[:-1, :])
        )
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 0] = x[-1, :].float()
        r = torch.sigmoid(xr @ self.receptance)
        k = torch.square(torch.relu(xk @ self.key))
        kv = k @ self.value
        return r * kv

    def forward(self, x, state, i: int, seq_mode: bool):
        return self.FF_seq(x, state, i) if seq_mode else self.FF_one(x, state, i)


class Block(nn.Module):
    def __init__(
        self, att: Attention, ffn: FeedForward, ln1: LayerNorm, ln2: LayerNorm
    ):
        super().__init__()
        self.add_module("att", att)
        self.add_module("ffn", ffn)
        self.add_module("ln1", ln1)
        self.add_module("ln2", ln2)

    def forward(self, x, state, i: int, n_embed: int, seq_mode: bool):
        x = x + self.att.forward(
            self.ln1.forward(x, n_embed),
            state,
            i,
            seq_mode=seq_mode,
        )
        x = x + self.ffn.forward(
            self.ln2.forward(x, n_embed),
            state,
            i,
            seq_mode=seq_mode,
        )
        return x


class RWKV(nn.Module):
    head: nn.Parameter
    emb: nn.Parameter
    n_embd: Final[int]
    float_dtype: Final[torch.dtype]
    device: Final[torch.device]

    RWKV_RESCALE_LAYER: Final[int] = 6

    def __init__(
        self,
        blocks: nn.ModuleList,
        head: nn.Parameter,
        emb: nn.Parameter,
        ln0: LayerNorm,
        ln_out: LayerNorm,
        n_embd: int,
        float_dtype: torch.dtype,
        device: torch.device,
    ):
        super().__init__()
        self.add_module("blocks", blocks)
        self.head = head
        self.emb = emb
        self.add_module("ln0", ln0)
        self.add_module("ln_out", ln_out)
        self.n_embd = n_embd
        self.float_dtype = float_dtype
        self.device = device

    def forward(
        self,
        tokens: List[int],
        state,
        preprocess_only: bool = False,
    ):
        seq_mode = len(tokens) > 1
        x = self.emb[tokens] if seq_mode else self.emb[tokens[-1]]
        x = x.to(device=self.device)
        for i, block in enumerate(self.blocks):
            x = block.forward(x, state, i, self.n_embd, seq_mode)
            # In fp16, halve the activations every RWKV_RESCALE_LAYER layers;
            # the corresponding weights were pre-divided at load time.
            if (
                self.float_dtype == torch.float16
                and (i + 1) % self.RWKV_RESCALE_LAYER == 0
            ):
                x = x / 2
        if preprocess_only:
            return torch.empty(1), state
        x = self.ln_out.forward(x[-1, :] if seq_mode else x, self.n_embd)
        x = self.head @ x
        return x.float(), state


class RWKVJIT(nn.Module):
    RWKV_RESCALE_LAYER = 6

    def __init__(
        self,
        *,
        model_path: str,
        float_dtype: torch.dtype,
        device: torch.device,
    ):
        super().__init__()
        self.float_dtype = float_dtype
        self.device = device

        with torch.no_grad():
            w = torch.load(model_path, map_location="cpu")
            n_embd = w["emb.weight"].shape[1]
            n_layer = 0
            keys = list(w.keys())
            print_need_newline = False
            # print(keys)
            for x in keys:
                w[x].requires_grad = False
                if x == "emb.weight" or "ln0" in x:
                    continue
                block_id = int(x.split(".")[1]) if ("blocks." in x) else 0
                n_layer = max(n_layer, block_id + 1)

                if ".time_" in x:
                    w[x] = w[x].squeeze()
                if (
                    "key.weight" in x
                    or "value.weight" in x
                    or "receptance.weight" in x
                    or "output.weight" in x
                ):
                    # Transpose linear weights so they can be applied as x @ W.
                    w[x] = w[x].t()

                if ".time_decay" in x:
                    w[x] = w[x].float()
                    w[x] = -torch.exp(w[x])
                elif ".time_first" in x:
                    w[x] = w[x].float()
                else:
                    w[x] = w[x].to(dtype=self.float_dtype)

                if float_dtype == torch.float16:
                    # Pre-divide these weights so the runtime halving every
                    # RWKV_RESCALE_LAYER layers keeps fp16 activations in range.
                    if "att.output.weight" in x:
                        w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER))
                    if "ffn.value.weight" in x:
                        w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER))

                w[x] = w[x].to(device=device)

                shape = w[x].shape
                shape = [i for i in shape if i != 1]
                if len(shape) > 1:
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
                else:
                    shape = f" {str(shape[0]).rjust(5)} "
                if block_id == 0:
                    if print_need_newline:
                        print("\n", end="")
                        print_need_newline = False
                    print(
                        x.ljust(32),
                        str(w[x].dtype).replace("torch.", "").ljust(10),
                        w[x].device,
                        shape,
                    )
                else:
                    print_need_newline = True
                    print(".", end="", flush=True)
            print()
            print(f"{n_layer=} {n_embd=}")

            self.n_layer = n_layer
            self.n_embd = n_embd

            emb = w["emb.weight"]
            ln_out = LayerNorm(w["ln_out.weight"], w["ln_out.bias"])
            ln0 = LayerNorm(w["blocks.0.ln0.weight"], w["blocks.0.ln0.bias"])
            head = w["head.weight"]
            blocks = nn.ModuleList(
                Block(
                    att=Attention(
                        time_mix_k=w[f"blocks.{i}.att.time_mix_k"],
                        time_mix_v=w[f"blocks.{i}.att.time_mix_v"],
                        time_mix_r=w[f"blocks.{i}.att.time_mix_r"],
                        time_first=w[f"blocks.{i}.att.time_first"],
                        time_decay=w[f"blocks.{i}.att.time_decay"],
                        key=w[f"blocks.{i}.att.key.weight"],
                        value=w[f"blocks.{i}.att.value.weight"],
                        receptance=w[f"blocks.{i}.att.receptance.weight"],
                        output=w[f"blocks.{i}.att.output.weight"],
                        float_dtype=self.float_dtype,
                    ),
                    ffn=FeedForward(
                        time_mix_k=w[f"blocks.{i}.ffn.time_mix_k"],
                        time_mix_r=w[f"blocks.{i}.ffn.time_mix_r"],
                        key=w[f"blocks.{i}.ffn.key.weight"],
                        value=w[f"blocks.{i}.ffn.value.weight"],
                        receptance=w[f"blocks.{i}.ffn.receptance.weight"],
                        float_dtype=self.float_dtype,
                    ),
                    ln1=LayerNorm(w[f"blocks.{i}.ln1.weight"], w[f"blocks.{i}.ln1.bias"]),
                    ln2=LayerNorm(w[f"blocks.{i}.ln2.weight"], w[f"blocks.{i}.ln2.bias"]),
                )
                for i in range(self.n_layer)
            )

        with torch.no_grad():  # precompute embedding
            x = ln0.forward(emb, n_embd)
            emb = x.to(dtype=self.float_dtype)

        self.add_module(
            "model",
            RWKV(
                blocks=blocks,
                head=head,
                emb=emb,
                ln0=ln0,
                ln_out=ln_out,
                n_embd=self.n_embd,
                float_dtype=self.float_dtype,
                device=self.device,
            ).to(device=device),
        )

    def forward(
        self,
        tokens: List[int],
        state: Optional[torch.Tensor],
        preprocess_only: bool = False,
    ):
        with torch.no_grad():
            if state is None:
                # Per layer the state holds 5 rows:
                # [ffn token-shift, att token-shift, att aa, att bb, att pp].
                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.device)
                for i in range(self.n_layer):
                    state[5 * i + 4] -= 1e30
            x, state = self.model.forward(
                tokens=tokens,
                state=state,
                preprocess_only=preprocess_only,
            )
        return x, state


@click.command()
@click.option(
    "--float-mode",
    type=click.Choice(["fp32", "fp16", "bf16"]),
    default="fp16",
)
@click.option("--device", type=click.Choice(["cpu", "cuda"]), default="cuda")
@click.option("--opt/--no-opt", default=True, help="Optimize TorchScript model")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
def convert(float_mode, device, opt, model_path, output_path):
    float_dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
| print(f"○ Loading model from {model_path}...") | |
| print() | |
| model = RWKVJIT( | |
| model_path=model_path, | |
| float_dtype=float_dtypes[float_mode], | |
| device=torch.device(device), | |
| ) | |
| print() | |
| print("○ Converting model to TorchScript...") | |
| model = torch.jit.script(model) | |
| if float_mode == "bf16": | |
| model = model.bfloat16() | |
| elif float_mode == "fp16": | |
| model = model.half() | |
| else: | |
| model = model.float() | |
| if device == "cuda": | |
| model = model.cuda() | |
| if opt: | |
| print("○ Optimizing model for inference...") | |
| model = torch.jit.optimize_for_inference(model) | |
| print(f"○ Saving model to {output_path}...") | |
| model.save(output_path) | |
| print("○ Done.") | |
| print() | |
| print("○ Note: 脚本编写时的 ChatRWKV 版本为:git+d42e9e7,其他版本可能不兼容。") | |
| if __name__ == "__main__": | |
| convert() |
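
For reference, a minimal usage sketch of the exported module. Everything specific here is illustrative, not part of the gist itself: the script filename convert_rwkv.py, the paths model.pth and model_jit.pt, and the token ids are placeholders; tokens would normally come from the RWKV tokenizer.

# Convert first, e.g.:
#   python convert_rwkv.py --float-mode fp16 --device cuda model.pth model_jit.pt
import torch

rwkv = torch.jit.load("model_jit.pt", map_location="cuda")  # assumed output path
state = None
# Sequence (prefill) call with a whole prompt, then single-token decode calls.
logits, state = rwkv([1, 2, 3], state)   # placeholder token ids
next_id = int(torch.argmax(logits))
logits, state = rwkv([next_id], state)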