This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def norm(x): | |
| return F.rms_norm(x, (x.shape[-1],)) | |
| def spectral_init(module): | |
| if hasattr(module, 'weight'): | |
| nn.init.orthogonal_(module.weight) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from matplotlib import pyplot as plt | |
| from torch.distributions import Normal | |
| import math | |
| import numpy as np | |
| import torch | |
| import random | |
| from tqdm import trange | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| K_components = 8 # try 4, 8, 16 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| # Quick sketch; would need to validate correctness and numerical stability, but I think it's right? | |
| # N elements if you are doing something like patch AR, say 768 for a 16x16x3 patch | |
| # For each mixture component => 9 params for L, plus 1 logit => total 10. | |
| # if num_mixtures=K, each point has 10*K parameters => (B,768,10K) output. | |
| # self.fc_out = nn.Linear(hidden_dim, 10 * num_mixtures * N_Elements) | |
| def unpack_3d_logistic_params(param_9): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| class broadcast_right: | |
| def __init__(self, dim=0): | |
| self.dim = dim | |
| self._old_add = None | |
| def __enter__(self): | |
| self._old_add = torch.Tensor.__add__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Smear(nn.Module):
    """Blend every key with the key at the previous sequence position.

    Position 0 is passed through unchanged. For each later position ``t`` the
    output is ``g * k[t] + (1 - g) * k[t - 1]``, where ``g`` is the sigmoid of
    a learned logit stored per head and per adjacent pair of positions.
    """

    def __init__(self, n_heads, seq_len):
        super().__init__()
        # One gate logit per head and per (previous, current) pair, hence
        # seq_len - 1 entries. Initialised to 1: sigmoid(1) ~= 0.73 goes to
        # the current key and ~0.27 to the previous key.
        self.alpha_values = torch.nn.Parameter(torch.ones(1, n_heads, seq_len - 1, 1))

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        """Return the smeared keys, same shape as ``k``.

        Args:
            k: Keys of shape (batch_size, n_heads, seq_len, d_k). The sequence
                length may be shorter than the one given at construction.

        Raises:
            ValueError: If ``k`` is longer than the configured ``seq_len``.
        """
        # Number of adjacent (previous, current) pairs in this input.
        n_pairs = max(k.shape[2] - 1, 0)
        if n_pairs > self.alpha_values.shape[2]:
            raise ValueError(
                f"sequence length {k.shape[2]} exceeds the configured maximum "
                f"{self.alpha_values.shape[2] + 1}"
            )

        # Use ONE gate slice for both terms so the two weights always sum to 1
        # and the slice length matches the number of pairs. The previous
        # slicing (1:T and -T:-1) produced seq_len - 2 gates for seq_len - 1
        # pairs at full length and unrelated gates for shorter inputs.
        gate = torch.sigmoid(self.alpha_values[:, :, :n_pairs, :])

        current = k[:, :, 1:, :]
        previous = k[:, :, :-1, :]
        smeared_k = current * gate + previous * (1 - gate)

        # The first position has no predecessor and is kept as-is.
        return torch.cat([k[:, :, 0:1, :], smeared_k], dim=2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class SubspaceLinear(nn.Linear): | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Forward pass for the BaseSubspaceLinear layer. Calls `subspace_weights` to sample from the subspace | |
| and uses the corresponding weight and bias. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| learning_rate = 4e-4 | |
| warmup_steps = 2000 | |
| log_step_interval = 1 | |
| eval_iters = 100 | |
| save_step_interval = 1000 | |
| eval_step_interval = 1000 | |
| weight_decay = 1e-1 | |
| beta1 = 0.9 | |
| beta2 = 0.95 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torchdiffeq import odeint | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
| import tqdm | |
| import imageio | |
| import os | |
| # Load target image and preprocess it |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Author Adam Hibble @algomancer | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import tqdm | |
| def get_padding(padding_type, kernel_size): | |
| assert padding_type in ['SAME', 'VALID'] | |
| if padding_type == 'SAME': | |
| return tuple((k - 1) // 2 for k in kernel_size) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Install useful stuff | |
| ! apt install --yes ssh screen nano htop ranger git > /dev/null | |
| # SSH setting | |
| ! echo "root:carbonara" | chpasswd | |
| ! echo "PasswordAuthentication yes" > /etc/ssh/sshd_config | |
| ! echo "PermitUserEnvironment yes" >> /etc/ssh/sshd_config | |
| ! echo "PermitRootLogin yes" >> /etc/ssh/sshd_config | |
| ! service ssh restart > /dev/null | |
| # Download ngrok | |
| ! wget -q -c -nc https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip |
NewerOlder