Convert a ChatRWKV model to a TorchScript version
"""Convert RWKV PyTorch savepoint to TorchScript model.
"""
from typing import Final, List, Optional
import click
import torch
import torch.nn as nn
import torch.nn.functional as F
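
# This script loads a ChatRWKV checkpoint (a plain dict of tensors), rewrites
# the weights into an inference-friendly form, wraps them in small modules that
# torch.jit.script can handle, and finally saves the scripted model.
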
class LayerNorm(nn.Module):
    weight: nn.Parameter
    bias: nn.Parameter

    def __init__(
        self,
        weight: nn.Parameter,
        bias: nn.Parameter,
    ):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, x, n_embd: int):
        return F.layer_norm(x, (n_embd,), weight=self.weight, bias=self.bias)
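

# Time-mixing (attention) path. Each layer i owns five rows of the recurrent
# state: row 5*i+0 holds the previous FFN input, row 5*i+1 the previous
# attention input, and rows 5*i+2..4 hold the WKV accumulators (numerator `aa`,
# denominator `bb`) plus the running exponent `pp` that keeps the exp() terms
# numerically stable (pp starts at -1e30 for a fresh state).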
class Attention(nn.Module):
    time_mix_k: nn.Parameter
    time_mix_v: nn.Parameter
    time_mix_r: nn.Parameter
    time_first: nn.Parameter
    time_decay: nn.Parameter
    key: nn.Parameter
    value: nn.Parameter
    receptance: nn.Parameter
    output: nn.Parameter
    float_dtype: Final[torch.dtype]

    def __init__(
        self,
        time_mix_k: nn.Parameter,
        time_mix_v: nn.Parameter,
        time_mix_r: nn.Parameter,
        time_first: nn.Parameter,
        time_decay: nn.Parameter,
        key: nn.Parameter,
        value: nn.Parameter,
        receptance: nn.Parameter,
        output: nn.Parameter,
        float_dtype: torch.dtype,
    ):
        super().__init__()
        self.time_mix_k = time_mix_k
        self.time_mix_v = time_mix_v
        self.time_mix_r = time_mix_r
        self.time_first = time_first
        self.time_decay = time_decay
        self.key = key
        self.value = value
        self.receptance = receptance
        self.output = output
        self.float_dtype = float_dtype
    def SA_one(self, x, state, i: int):
        xx = state[5 * i + 1].to(dtype=self.float_dtype)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 1] = x.float()
        r = torch.sigmoid(xr @ self.receptance)
        k = (xk @ self.key).float()
        v = (xv @ self.value).float()
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        ww = self.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + self.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = p
        wkv = (a / b).to(dtype=self.float_dtype)
        return (r * wkv) @ self.output

    def SA_seq(self, x, state, i: int):
        xx = torch.cat(
            (state[5 * i + 1].to(dtype=self.float_dtype).unsqueeze(0), x[:-1, :])
        )
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 1] = x[-1, :].float()
        r = torch.sigmoid(xr @ self.receptance)
        k = (xk @ self.key).float()
        v = (xv @ self.value).float()
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        T = x.shape[0]
        for t in range(T):
            ww = self.time_first + k[t]
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            a = e1 * aa + e2 * v[t]
            b = e1 * bb + e2
            ww = pp + self.time_decay
            p = torch.maximum(ww, k[t])
            e1 = torch.exp(ww - p)
            e2 = torch.exp(k[t] - p)
            if t != T - 1:
                aa = e1 * aa + e2 * v[t]
                bb = e1 * bb + e2
                pp = p
            else:
                state[5 * i + 2] = e1 * aa + e2 * v[t]
                state[5 * i + 3] = e1 * bb + e2
                state[5 * i + 4] = p
            xx[t] = (a / b).to(dtype=self.float_dtype)
        return (r * xx) @ self.output

    def forward(self, x, state, i: int, seq_mode: bool):
        return self.SA_seq(x, state, i) if seq_mode else self.SA_one(x, state, i)
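

# Channel-mixing (FFN) path: the output is sigmoid(xr @ receptance) gating
# square(relu(xk @ key)) @ value, with the same single-token / sequence split
# as the attention path.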
class FeedForward(nn.Module):
    time_mix_k: nn.Parameter
    time_mix_r: nn.Parameter
    key: nn.Parameter
    value: nn.Parameter
    receptance: nn.Parameter
    float_dtype: Final[torch.dtype]

    def __init__(
        self,
        time_mix_k: nn.Parameter,
        time_mix_r: nn.Parameter,
        key: nn.Parameter,
        value: nn.Parameter,
        receptance: nn.Parameter,
        float_dtype: torch.dtype,
    ):
        super().__init__()
        self.time_mix_k = time_mix_k
        self.time_mix_r = time_mix_r
        self.key = key
        self.value = value
        self.receptance = receptance
        self.float_dtype = float_dtype

    def FF_one(self, x, state, i: int):
        xx = state[5 * i + 0].to(dtype=self.float_dtype)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 0] = x.float()
        r = torch.sigmoid(xr @ self.receptance)
        k = torch.square(torch.relu(xk @ self.key))
        kv = k @ self.value
        return r * kv

    def FF_seq(self, x, state, i: int):
        xx = torch.cat(
            (state[5 * i + 0].to(dtype=self.float_dtype).unsqueeze(0), x[:-1, :])
        )
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        state[5 * i + 0] = x[-1, :].float()
        r = torch.sigmoid(xr @ self.receptance)
        k = torch.square(torch.relu(xk @ self.key))
        kv = k @ self.value
        return r * kv

    def forward(self, x, state, i: int, seq_mode: bool):
        return self.FF_seq(x, state, i) if seq_mode else self.FF_one(x, state, i)
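

# One block uses the pre-LayerNorm residual structure:
# x = x + att(ln1(x)), then x = x + ffn(ln2(x)).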
class Block(nn.Module):
    def __init__(
        self, att: Attention, ffn: FeedForward, ln1: LayerNorm, ln2: LayerNorm
    ):
        super().__init__()
        self.add_module("att", att)
        self.add_module("ffn", ffn)
        self.add_module("ln1", ln1)
        self.add_module("ln2", ln2)

    def forward(self, x, state, i: int, n_embed: int, seq_mode: bool):
        x = x + self.att.forward(
            self.ln1.forward(x, n_embed),
            state,
            i,
            seq_mode=seq_mode,
        )
        x = x + self.ffn.forward(
            self.ln2.forward(x, n_embed),
            state,
            i,
            seq_mode=seq_mode,
        )
        return x
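

# The full model: embedding lookup (ln0 is folded into the embedding table at
# load time), the stack of blocks, ln_out, and the output head. In fp16 mode
# the hidden state is halved every RWKV_RESCALE_LAYER blocks to avoid overflow;
# the loader pre-divides the matching att.output / ffn.value weights to
# compensate.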
class RWKV(nn.Module):
    head: nn.Parameter
    emb: nn.Parameter
    n_embd: Final[int]
    float_dtype: Final[torch.dtype]
    device: Final[torch.device]

    RWKV_RESCALE_LAYER: Final[int] = 6

    def __init__(
        self,
        blocks: nn.ModuleList,
        head: nn.Parameter,
        emb: nn.Parameter,
        ln0: LayerNorm,
        ln_out: LayerNorm,
        n_embd: int,
        float_dtype: torch.dtype,
        device: torch.device,
    ):
        super().__init__()
        self.add_module("blocks", blocks)
        self.head = head
        self.emb = emb
        self.add_module("ln0", ln0)
        self.add_module("ln_out", ln_out)
        self.n_embd = n_embd
        self.float_dtype = float_dtype
        self.device = device

    def forward(
        self,
        tokens: List[int],
        state,
        preprocess_only: bool = False,
    ):
        seq_mode = len(tokens) > 1
        x = self.emb[tokens] if seq_mode else self.emb[tokens[-1]]
        x = x.to(device=self.device)

        for i, block in enumerate(self.blocks):
            x = block.forward(x, state, i, self.n_embd, seq_mode)
            if (
                self.float_dtype == torch.float16
                and (i + 1) % self.RWKV_RESCALE_LAYER == 0
            ):
                x = x / 2

        if preprocess_only:
            return torch.empty(1), state

        x = self.ln_out.forward(x[-1, :] if seq_mode else x, self.n_embd)
        x = self.head @ x
        return x.float(), state
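

# Loader / wrapper: reads the ChatRWKV checkpoint with torch.load, rewrites the
# raw tensors (see the comments inside __init__), builds the RWKV module above,
# and exposes a forward() that creates a fresh recurrent state when `state` is
# None.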
class RWKVJIT(nn.Module):
    RWKV_RESCALE_LAYER = 6

    def __init__(
        self,
        *,
        model_path: str,
        float_dtype: torch.dtype,
        device: torch.device,
    ):
        super().__init__()
        self.float_dtype = float_dtype
        self.device = device

        with torch.no_grad():
            w = torch.load(model_path, map_location="cpu")
            n_embd = w["emb.weight"].shape[1]
            n_layer = 0

            keys = list(w.keys())
            print_need_newline = False
            # print(keys)
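
            # Per-tensor rewrites applied in the loop below:
            #   * .time_* parameters are squeezed to 1-D;
            #   * key/value/receptance/output linear weights are transposed so
            #     the modules can use right-multiplication (x @ W);
            #   * time_decay is stored as -exp(time_decay); time_first stays fp32;
            #   * everything else is cast to the target float dtype;
            #   * in fp16 mode, att.output.weight and ffn.value.weight of block i
            #     are pre-divided by 2 ** (i // RWKV_RESCALE_LAYER) to match the
            #     runtime halving of the hidden state.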
            for x in keys:
                w[x].requires_grad = False
                if x == "emb.weight" or "ln0" in x:
                    continue

                block_id = int(x.split(".")[1]) if ("blocks." in x) else 0
                n_layer = max(n_layer, block_id + 1)

                if ".time_" in x:
                    w[x] = w[x].squeeze()
                if (
                    "key.weight" in x
                    or "value.weight" in x
                    or "receptance.weight" in x
                    or "output.weight" in x
                ):
                    w[x] = w[x].t()

                if ".time_decay" in x:
                    w[x] = w[x].float()
                    w[x] = -torch.exp(w[x])
                elif ".time_first" in x:
                    w[x] = w[x].float()
                else:
                    w[x] = w[x].to(dtype=self.float_dtype)

                if float_dtype == torch.float16:
                    if "att.output.weight" in x:
                        w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER))
                    if "ffn.value.weight" in x:
                        w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER))

                w[x] = w[x].to(device=device)

                shape = w[x].shape
                shape = [i for i in shape if i != 1]
                if len(shape) > 1:
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
                else:
                    shape = f" {str(shape[0]).rjust(5)} "

                if block_id == 0:
                    if print_need_newline:
                        print("\n", end="")
                        print_need_newline = False
                    print(
                        x.ljust(32),
                        str(w[x].dtype).replace("torch.", "").ljust(10),
                        w[x].device,
                        shape,
                    )
                else:
                    print_need_newline = True
                    print(".", end="", flush=True)

            print()
            print(f"{n_layer=} {n_embd=}")

        self.n_layer = n_layer
        self.n_embd = n_embd
        emb = w["emb.weight"]
        ln_out = LayerNorm(w["ln_out.weight"], w["ln_out.bias"])
        ln0 = LayerNorm(w["blocks.0.ln0.weight"], w["blocks.0.ln0.bias"])
        head = w["head.weight"]

        blocks = nn.ModuleList(
            Block(
                att=Attention(
                    time_mix_k=w[f"blocks.{i}.att.time_mix_k"],
                    time_mix_v=w[f"blocks.{i}.att.time_mix_v"],
                    time_mix_r=w[f"blocks.{i}.att.time_mix_r"],
                    time_first=w[f"blocks.{i}.att.time_first"],
                    time_decay=w[f"blocks.{i}.att.time_decay"],
                    key=w[f"blocks.{i}.att.key.weight"],
                    value=w[f"blocks.{i}.att.value.weight"],
                    receptance=w[f"blocks.{i}.att.receptance.weight"],
                    output=w[f"blocks.{i}.att.output.weight"],
                    float_dtype=self.float_dtype,
                ),
                ffn=FeedForward(
                    time_mix_k=w[f"blocks.{i}.ffn.time_mix_k"],
                    time_mix_r=w[f"blocks.{i}.ffn.time_mix_r"],
                    key=w[f"blocks.{i}.ffn.key.weight"],
                    value=w[f"blocks.{i}.ffn.value.weight"],
                    receptance=w[f"blocks.{i}.ffn.receptance.weight"],
                    float_dtype=self.float_dtype,
                ),
                ln1=LayerNorm(w[f"blocks.{i}.ln1.weight"], w[f"blocks.{i}.ln1.bias"]),
                ln2=LayerNorm(w[f"blocks.{i}.ln2.weight"], w[f"blocks.{i}.ln2.bias"]),
            )
            for i in range(self.n_layer)
        )
        with torch.no_grad():  # precompute embedding
            x = ln0.forward(emb, n_embd)
            emb = x.to(dtype=self.float_dtype)

        self.add_module(
            "model",
            RWKV(
                blocks=blocks,
                head=head,
                emb=emb,
                ln0=ln0,
                ln_out=ln_out,
                n_embd=self.n_embd,
                float_dtype=self.float_dtype,
                device=self.device,
            ).to(device=device),
        )
    def forward(
        self,
        tokens: List[int],
        state: Optional[torch.Tensor],
        preprocess_only: bool = False,
    ):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.device)
                for i in range(self.n_layer):
                    state[5 * i + 4] -= 1e30
            x, state = self.model.forward(
                tokens=tokens,
                state=state,
                preprocess_only=preprocess_only,
            )
        return x, state
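

# Example invocation of the converter below (the script and file names are
# placeholders):
#
#     python convert_rwkv.py --float-mode fp16 --device cuda \
#         RWKV-4-Pile-7B-xxxx.pth rwkv-7b-fp16.pt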
@click.command()
@click.option(
    "--float-mode",
    type=click.Choice(["fp32", "fp16", "bf16"]),
    default="fp16",
)
@click.option("--device", type=click.Choice(["cpu", "cuda"]))
@click.option("--opt/--no-opt", default=True, help="Optimize TorchScript model")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
def convert(float_mode, device, opt, model_path, output_path):
    float_dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    print(f"○ Loading model from {model_path}...")
    print()
    model = RWKVJIT(
        model_path=model_path,
        float_dtype=float_dtypes[float_mode],
        device=torch.device(device),
    )
    print()

    print("○ Converting model to TorchScript...")
    model = torch.jit.script(model)
    if float_mode == "bf16":
        model = model.bfloat16()
    elif float_mode == "fp16":
        model = model.half()
    else:
        model = model.float()
    if device == "cuda":
        model = model.cuda()

    if opt:
        print("○ Optimizing model for inference...")
        model = torch.jit.optimize_for_inference(model)

    print(f"○ Saving model to {output_path}...")
    model.save(output_path)
    print("○ Done.")
    print()
print("○ Note: 脚本编写时的 ChatRWKV 版本为:git+d42e9e7,其他版本可能不兼容。")
if __name__ == "__main__":
    convert()
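

# After conversion, the exported file can be used without this script.
# A rough usage sketch (the file name and token ids are placeholders):
#
#     model = torch.jit.load("rwkv-7b-fp16.pt")
#     logits, state = model(prompt_tokens, None)   # sequence mode for the prompt
#     logits, state = model([next_token], state)   # then one token at a time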