To reproduce the decoder training run:

1. Get dalle2.
2. Get the config file: https://wandb.ai/rom1504/dalle2_train_decoder/runs/mic5buox/files/decoder_config.json (one way to fetch it is sketched below).
3. Get these 2 .sh scripts.
4. Run `sbatch start_big.sh`.
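A minimal sketch of step 2, assuming the run is public and the `wandb` Python client is installed (the run path is read off the URL above):

```python
# Sketch: download decoder_config.json for the W&B run referenced above.
# Assumes the run is publicly readable and `pip install wandb` has been done.
import wandb

api = wandb.Api()
run = api.run("rom1504/dalle2_train_decoder/mic5buox")
run.file("decoder_config.json").download(replace=True)
```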
```python
import gc
from typing import Tuple

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import triton.testing
from kernels import get_kernel
```
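For context, a minimal sketch of how `triton.testing` is typically used to time a workload; the softmax baseline below is an illustration, not code from this file:

```python
import torch
import torch.nn.functional as F
import triton.testing

# do_bench runs the closure repeatedly on the GPU and reports a time in ms.
x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
ms = triton.testing.do_bench(lambda: F.softmax(x, dim=-1))
print(f"softmax baseline: {ms:.3f} ms")
```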
```python
import torch


def mel_filterbank(
    n_freqs,
    f_min,
    f_max,
    n_mels,
    sample_rate,
    norm=None,
    mel_scale="htk",
):
    # Builds a mel filterbank matrix, typically of shape (n_freqs, n_mels);
    # the body is truncated in the source.
    ...
```
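A hedged usage sketch, assuming the function mirrors `torchaudio.functional.melscale_fbanks` and returns an `(n_freqs, n_mels)` matrix; the shapes and values below are illustrative, not taken from this file:

```python
# Hypothetical call; the body of mel_filterbank is not shown above, so this only
# illustrates the expected interface.
fb = mel_filterbank(
    n_freqs=513,        # e.g. n_fft // 2 + 1 for n_fft = 1024
    f_min=0.0,
    f_max=8000.0,
    n_mels=80,
    sample_rate=16000,
    mel_scale="htk",
)
print(fb.shape)  # expected: torch.Size([513, 80])
```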
```python
from __future__ import annotations

from functools import reduce
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import einx
```
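Of the imports above, `pad_sequence` is the standard way to batch variable-length sequences; a minimal sketch, not specific to this file's model code:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths, each with feature dim 8.
seqs = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]
batch = pad_sequence(seqs, batch_first=True)  # zero-padded to the longest sequence
print(batch.shape)  # torch.Size([3, 5, 8])
```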
```cuda
#include <stdio.h>

// Check tensor core's warp register layout.
// nvcc -arch=sm_75 tensorcore_mapping.cu -o mapping
// ./mapping

// Define some error checking macros.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
    }
}
```
Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.
TL;DR: you can run multi-head attention (forward + backward) faster and with no extra memory, for any sequence length and head dim. We'd love to make it available via apex, and we need your advice on how best to do that.
Why should I care? Here's how it compares against standard multi-head attention (blue) for one multi-head attention layer of GPT-J on an RTX 3080 Ti.
(Plots: time with backward (ms) and peak VRAM allocated (MB), fused vs. standard attention.)
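For reference, a minimal sketch of the standard, unfused attention baseline the plots compare against; it materializes the full `(seq, seq)` score matrix, which is the extra memory a fused kernel avoids:

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # (batch, heads, seq, seq)
    return F.softmax(scores, dim=-1) @ v
```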
```python
from typing import Tuple

import torch
import torch.nn.functional as F
import itertools
from timeit import default_timer as timer


class SoftmaxWeightedMean(torch.autograd.Function):
    @staticmethod
```
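The snippet cuts off after the decorator. Purely as an illustration of the `torch.autograd.Function` contract such a class would implement, here is one guess at a softmax-weighted mean with a hand-written backward; the name, shapes, and semantics are assumptions, not this file's actual code:

```python
import torch
import torch.nn.functional as F


class SoftmaxWeightedMeanSketch(torch.autograd.Function):
    """Illustrative guess: out = sum_i softmax(w)_i * x_i, with x of shape (n, d) and w of shape (n,)."""

    @staticmethod
    def forward(ctx, x, w):
        p = F.softmax(w, dim=0)
        out = (p.unsqueeze(-1) * x).sum(dim=0)
        ctx.save_for_backward(p, x, out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        p, x, out = ctx.saved_tensors
        grad_x = p.unsqueeze(-1) * grad_out   # d out / d x_i = p_i * I
        grad_w = p * ((x - out) @ grad_out)   # softmax Jacobian folded into one line
        return grad_x, grad_w
```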
```python
import torch
import torch.utils.dlpack
import jax
import jax.dlpack


# A generic mechanism for turning a JAX function into a PyTorch function.
def j2t(x_jax):
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch
```
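The snippet shows the JAX-to-PyTorch direction; a sketch of the reverse direction using the same DLPack round trip (the `t2j` name is an assumption, mirroring `j2t`):

```python
import torch
import torch.utils.dlpack
import jax
import jax.dlpack


def t2j(x_torch):
    # PyTorch -> JAX via DLPack; contiguous() avoids handing DLPack a layout
    # the consumer may reject.
    x_torch = x_torch.contiguous()
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
```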
| """Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.""" | |
| import math | |
| import torch | |
| from torch import optim | |
| class ComplexSGD(optim.Optimizer): | |
| def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.): |
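A hedged usage sketch, assuming `ComplexSGD` implements the usual `torch.optim.Optimizer` interface its constructor suggests (in particular a `step()` method, which is not shown above):

```python
import math
import torch

model = torch.nn.Linear(16, 4)
opt = ComplexSGD(model.parameters(), lr=1e-2, momentum=0.9, angle=math.pi / 8)

loss = model(torch.randn(8, 16)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
```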
```python
from functools import partial

import torch


def _const(example, val):
    # Scalar constant with the same dtype as `example`.
    return torch.tensor(val, dtype=example.dtype)


def pad(x, axis, side):
    shape = list(x.size())
    # Normalize a negative axis to its positive index.
    if axis == -1:
        axis = len(shape) - 1
```
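The `pad` helper is truncated here. For illustration only, one common way such a single-element pad is written on top of `F.pad`; the `pad_one` name, the zero fill value, and the one-element width are assumptions, not the original code:

```python
import torch
import torch.nn.functional as F


def pad_one(x, axis, side="left"):
    # F.pad's pad list starts at the last dimension, so insert [0, 0] pairs
    # until the pair for `axis` lines up, then pad one zero on the chosen side.
    if axis < 0:
        axis += x.dim()
    pads = [0, 0] * (x.dim() - 1 - axis)
    pads += [1, 0] if side == "left" else [0, 1]
    return F.pad(x, pads)
```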