Testing GALA (adapted to Muon and Adam) for NanoGPT speedrun
import os
import subprocess
import sys
from typing import Tuple
with open(sys.argv[0]) as f:
code = f.read()
import copy
import glob
import math
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.optimizer import ParamsT
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng
# ----------------------------------------------------------------------------------------------------------------------------------------------------
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(
x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous()
x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
assert x.ndim == w.ndim == 2
assert x.shape[1] == w.shape[1]
assert x.device == w.device
assert x.is_contiguous() and w.is_contiguous()
return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(
g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float
) -> tuple[Tensor, Tensor]:
@torch.compile
def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
assert grad.is_contiguous()
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
grad_x = torch._scaled_mm(
grad_f8,
w_f8.T.contiguous().T,
out_dtype=torch.bfloat16,
scale_a=grad_inv_s,
scale_b=w_inv_s,
use_fast_accum=False,
)
# faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
grad_w = torch._scaled_mm(
x_f8.T.contiguous(),
grad_f8.T.contiguous().T,
out_dtype=torch.float32,
scale_a=x_inv_s,
scale_b=grad_inv_s,
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
def backward(ctx, grad_out: Tensor, *_):
x_f8, w_f8 = ctx.saved_tensors
x_s, w_s, grad_s = ctx.scales
grad_x, grad_w = torch.ops.nanogpt.mm_backward(
grad_out, x_f8, w_f8, x_s, w_s, grad_s
)
return grad_x, grad_w, None, None, None
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
*_, x_s, w_s, grad_s = inputs
_, x_f8, w_f8 = output
ctx.save_for_backward(x_f8, w_f8)
ctx.scales = x_s, w_s, grad_s
ctx.set_materialize_grads(False)
mm_op.register_autograd(backward, setup_context=setup_context)
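# Plain-PyTorch reference for what nanogpt::mm computes above (a reading aid,
# not called anywhere in this run; the helper name is ours). x is quantized to
# e4m3 with per-tensor scale x_s, w with w_s, and _scaled_mm dequantizes the
# product with those same scales, so out ~= x @ w.T up to FP8 rounding.
def _mm_fp8_reference(x: Tensor, w: Tensor, x_s: float, w_s: float) -> Tensor:
    x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
    w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
    # dequantize and matmul in float32, then cast to match mm_op's bfloat16 output
    return ((x_f8.to(torch.float32) * x_s) @ (w_f8.to(torch.float32) * w_s).T).bfloat16()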
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Optimizers
# ----------------------------------------------------------------------------------------------------------------------------------------------------
class GalaAdam(torch.optim.Adam):
def __init__(
self,
params: ParamsT,
lr: float = 0.003,
betas: Tuple[float, float] = (0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
gala_delta=1e-8,
max_lr=None,
):
super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.gala_delta = gala_delta
self.max_lr = max_lr
self._gala_saved_info = None
for group in self.param_groups:
device = group["params"][0].device
group.setdefault("sum_alignment", torch.tensor(0.0, device=device))
group.setdefault("sum_reg", torch.tensor(0.0, device=device))
@torch.no_grad()
def step(self):
super().step()
@torch.no_grad()
def gala_phase1(self):
saved_info = []
for group in self.param_groups:
lr = group["lr"]
for p in group["params"]:
if p.grad is None:
grad = torch.zeros_like(p)
else:
grad = p.grad.detach().clone()
saved_info.append({"grad_t": grad, "lr": lr})
super().step()
self._gala_saved_info = saved_info
@torch.no_grad()
def gala_phase2(self):
if self._gala_saved_info is None:
            raise RuntimeError("gala_phase1() must be called before gala_phase2().")
saved_info = self._gala_saved_info
idx = 0
for group in self.param_groups:
sum_align = group.get("sum_alignment", 0.0)
sum_reg = group.get("sum_reg", 0.0)
beta1, beta2 = group["betas"]
eps = group["eps"]
for p in group["params"]:
grad_t = saved_info[idx]["grad_t"]
grad_tp1 = (
p.grad.detach().clone()
if p.grad is not None
else torch.zeros_like(p)
)
lr = saved_info[idx]["lr"]
state = self.state[p]
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
step_num = state["step"]
bias_correction_1 = 1 - beta1**step_num
bias_correction_2 = 1 - beta2**step_num
exp_avg_corr = exp_avg / bias_correction_1
exp_avg_sq_corr = exp_avg_sq / bias_correction_2
denom = exp_avg_sq_corr.sqrt().add_(eps)
d_t = exp_avg_corr / denom
align = torch.sum(d_t * grad_t)
diff = grad_tp1 - grad_t
curv = (diff.norm() * grad_t.norm()) + (lr + self.gala_delta)
sum_align += align
sum_reg += curv
idx += 1
group["sum_alignment"] = sum_align
group["sum_reg"] = sum_reg
            # TODO: Tensor-ify LR. Currently runs into an insane number of graph breaks under torch.compile; see if this can be fixed later.
new_lr = (sum_align / (sum_reg + self.gala_delta)).item()
if self.max_lr is not None:
new_lr = max(min(new_lr, self.max_lr), 0.0)
group["lr"] = new_lr
self._gala_saved_info = None
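# LR rule implemented by gala_phase2 (here and in GalaMuon below). Per group,
# with d_t the optimizer's update direction, g_t the pre-step gradient, and
# g_{t+1} the gradient recomputed after the step:
#
#   lr_new = sum_p <d_t(p), g_t(p)>
#            / ( sum_p ( ||g_{t+1}(p) - g_t(p)|| * ||g_t(p)|| + lr + delta ) + delta )
#
# clipped to [0, max_lr] when max_lr is set. Note the sums are stored on the
# group and keep accumulating across GALA updates, so this is an alignment
# numerator over a curvature-style denominator built from running totals.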
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
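# Each Newton-Schulz iteration above acts on the singular values of X
# independently: with A = X @ X.mT, the update X <- a*X + (b*A + c*A @ A) @ X
# maps every singular value s to p(s) = a*s + b*s**3 + c*s**5. A scalar sketch
# of that fixed-point map (illustrative only; the helper name is ours):
def _ns_scalar_map(s: float, steps: int = 5) -> float:
    a, b, c = (3.4445, -4.7750, 2.0315)
    for _ in range(steps):
        s = a * s + b * s**3 + c * s**5
    return s  # for s in (0, 1], lands roughly in (0.5, 1.5) per the docstring above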
class Muon(torch.optim.Optimizer):
def __init__(
self,
params,
lr=0.02,
momentum=0.95,
nesterov=True,
ns_steps=5,
):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
super().__init__(params, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
lr = group["lr"]
momentum = group["momentum"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
for p in group["params"]:
grad: Tensor = p.grad
if grad is None:
continue
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(grad, 1 - momentum)
grad = grad.lerp_(buf, momentum) if nesterov else buf
original_shape = grad.shape
if grad.ndim > 2:
grad = grad.flatten(1)
grad = zeropower_via_newtonschulz5(grad, steps=ns_steps)
grad = grad.view(original_shape)
else:
grad = zeropower_via_newtonschulz5(grad, steps=ns_steps)
scale = -lr * math.sqrt(max(1.0, p.size(-2) / p.size(-1)))
p.add_(grad, alpha=scale)
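# One Muon step per parameter W in R^{m x n} (momentum-blended gradient M,
# Nesterov by default):
#   W <- W - lr * sqrt(max(1, m / n)) * NS5(M)
# where NS5 is the Newton-Schulz orthogonalization above and the sqrt factor
# is Muon's shape-dependent scaling for matrices with more rows than columns.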
class GalaMuon(Muon):
def __init__(
self,
params,
lr: float,
momentum: float = 0.95,
nesterov: bool = True,
ns_steps: int = 5,
gala_delta=1e-8,
max_lr=None,
):
super().__init__(
params, lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps
)
self.gala_delta = gala_delta
self.max_lr = max_lr
self._gala_saved_info = None
for group in self.param_groups:
device = group["params"][0].device
group.setdefault("sum_alignment", torch.tensor(0.0, device=device))
group.setdefault("sum_reg", torch.tensor(0.0, device=device))
@torch.no_grad()
def step(self):
super().step()
@torch.no_grad()
def gala_phase1(self):
saved_info = []
for group in self.param_groups:
lr = group["lr"]
for p in group["params"]:
if p.grad is None:
grad = torch.zeros_like(p)
else:
grad = p.grad.detach().clone()
saved_info.append({"grad_t": grad, "lr": lr})
super().step()
self._gala_saved_info = saved_info
@torch.no_grad()
def gala_phase2(self):
if self._gala_saved_info is None:
            raise RuntimeError("gala_phase1() must be called before gala_phase2().")
saved_info = self._gala_saved_info
idx = 0
for group in self.param_groups:
sum_align = group.get("sum_alignment", 0.0)
sum_reg = group.get("sum_reg", 0.0)
momentum = group["momentum"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
for p in group["params"]:
grad_t = saved_info[idx]["grad_t"]
grad_tp1 = (
p.grad.detach().clone()
if p.grad is not None
else torch.zeros_like(p)
)
lr = saved_info[idx]["lr"]
state = self.state[p]
if "momentum_buffer" not in state:
                    # TODO: Can I remove this check? Since this is phase2, the momentum buffer is guaranteed to have been filled.
state["momentum_buffer"] = torch.zeros_like(grad_tp1)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(grad_tp1, 1 - momentum)
grad_tp1 = grad_tp1.lerp_(buf, momentum) if nesterov else buf
original_shape = grad_tp1.shape
if grad_tp1.ndim > 2:
grad_tp1 = grad_tp1.flatten(1)
grad_tp1 = zeropower_via_newtonschulz5(grad_tp1, steps=ns_steps)
grad_tp1 = grad_tp1.view(original_shape)
else:
grad_tp1 = zeropower_via_newtonschulz5(grad_tp1, steps=ns_steps)
align = torch.sum(grad_tp1 * grad_t) # grad_tp1 <-- post-Muon direction
diff = grad_tp1 - grad_t
curv = (diff.norm() * grad_t.norm()) + (lr + self.gala_delta)
sum_align += align
sum_reg += curv
idx += 1
group["sum_alignment"] = sum_align
group["sum_reg"] = sum_reg
new_lr = (sum_align / (sum_reg + self.gala_delta)).item()
if self.max_lr is not None:
new_lr = max(min(new_lr, self.max_lr), 0.0)
group["lr"] = new_lr
self._gala_saved_info = None
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Model Definition
# ----------------------------------------------------------------------------------------------------------------------------------------------------
def rms_norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
def next_multiple_of_n(v: float | int, *, n: int):
return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
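# e.g. next_multiple_of_n(50257, n=128) == 50304, the padded vocab size used by lm_head below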
class CastedLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
use_fp8=False,
x_s=1.0,
w_s=1.0,
grad_s=1.0,
):
super().__init__(in_features, out_features, bias=False)
self.use_fp8 = use_fp8
self.x_s = x_s
self.w_s = w_s
self.grad_s = grad_s
def reset_parameters(self) -> None:
std = 0.5 * (
self.in_features**-0.5
) # 0.5 is a bit better than the default 1/sqrt(3)
bound = (3**0.5) * std
with torch.no_grad():
self.weight.uniform_(-bound, bound)
def forward(self, x: Tensor):
if self.use_fp8 and self.training:
_x = x.flatten(0, -2)
out: Tensor = torch.ops.nanogpt.mm(
_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s
)[0]
return out.reshape(*x.shape[:-1], -1)
else:
return F.linear(x, self.weight.type_as(x))
class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
super().__init__()
# half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
angular_freq = (1 / 1024) ** torch.linspace(
0, 1, steps=dim // 4, dtype=torch.float32
)
angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
t = torch.arange(max_seq_len, dtype=torch.float32)
theta = torch.einsum("i,j -> ij", t, angular_freq)
self.cos = nn.Buffer(theta.cos(), persistent=False)
self.sin = nn.Buffer(theta.sin(), persistent=False)
def forward(self, x_BTHD: Tensor):
assert self.cos.size(0) >= x_BTHD.size(-3)
cos, sin = (
self.cos[None, : x_BTHD.size(-3), None, :],
self.sin[None, : x_BTHD.size(-3), None, :],
)
x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat((y1, y2), 3).type_as(x_BTHD)
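# Note on the half-truncation above: angular_freq is dim // 4 nonzero
# frequencies followed by dim // 4 zeros, so half of the (x1, x2) rotation
# pairs see a zero angle (cos=1, sin=0) and pass through unrotated.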
class CausalSelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
hdim = num_heads * head_dim
std = 0.5 * (dim**-0.5)
bound = (3**0.5) * std # improved init scale by @YouJiacheng
# merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
# https://x.com/hi_tysam/status/1879699187107033311
self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
self.rotary = Rotary(head_dim, max_seq_len)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
B, T = x.size(0), x.size(1) # batch size, sequence length
assert B == 1, "Must use batch size = 1 for FlexAttention"
q, k, v = (
F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x))
.view(B, T, 3 * self.num_heads, self.head_dim)
.chunk(3, dim=-2)
)
q, k = rms_norm(q), rms_norm(k) # QK norm @Grad62304977
q, k = self.rotary(q), self.rotary(k)
if ve is not None:
v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(
v
) # @KoszarskyB & @Grad62304977
else: # skip mid-layers token value embeddings by @YouJiacheng
v = self.lambdas[0] * v
# scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
# inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
y = flex_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
block_mask=block_mask,
scale=0.12,
).transpose(1, 2)
y = y.contiguous().view(
B, T, self.num_heads * self.head_dim
) # re-assemble all head outputs side by side
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, dim: int):
super().__init__()
hdim = 4 * dim
self.c_fc = CastedLinear(dim, hdim)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
def forward(self, x: Tensor):
x = self.c_fc(x)
x = F.relu(
x
).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
super().__init__()
# skip attention of blocks.7 (the 8th layer) by @YouJiacheng
self.attn = (
CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
)
self.mlp = MLP(dim)
self.lambdas = nn.Parameter(torch.tensor([1.0, 0.0]))
def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
x = self.lambdas[0] * x + self.lambdas[1] * x0
if self.attn is not None:
x = x + self.attn(rms_norm(x), ve, block_mask)
x = x + self.mlp(rms_norm(x))
return x
class GPT(nn.Module):
def __init__(
self,
vocab_size: int,
num_layers: int,
num_heads: int,
model_dim: int,
max_seq_len: int,
):
super().__init__()
self.embed = nn.Embedding(vocab_size, model_dim)
# token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
# value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
self.value_embeds = nn.ModuleList(
[nn.Embedding(vocab_size, model_dim) for _ in range(3)]
)
self.blocks = nn.ModuleList(
[Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]
)
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
# suggested to me by @Grad62304977. this originates from Karpathy's experiments.
self.lm_head = CastedLinear(
model_dim,
next_multiple_of_n(vocab_size, n=128),
use_fp8=True,
x_s=(model_dim**0.5) / 448,
w_s=24 / 448,
grad_s=1 / 448,
)
self.lm_head.weight.detach().zero_() # @Grad62304977
# Add learnable skip connection weights for decoder layers
assert num_layers % 2 == 0
self.skip_weights = nn.Parameter(torch.ones(num_layers // 2))
def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
BLOCK_SIZE = 128
docs = (input_seq == 50256).cumsum(0)
def document_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[q_idx] == docs[kv_idx]
return causal_mask & document_mask
def dense_to_ordered(dense_blockmask: Tensor):
num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
indices = (
dense_blockmask.argsort(dim=-1, descending=False, stable=True)
.flip(-1)
.to(torch.int32)
)
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
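        # dense_to_ordered converts a dense [Q, KV] boolean block mask into the
        # (num_blocks, indices) pair expected by BlockMask.from_kv_blocks below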
# manual block mask creation by @YouJiacheng
assert len(input_seq) % BLOCK_SIZE == 0
NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
causal_blockmask_any = block_idx[:, None] >= block_idx
causal_blockmask_all = block_idx[:, None] > block_idx
docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
document_blockmask_any = (docs_low[:, None] <= docs_high) & (
docs_high[:, None] >= docs_low
)
document_blockmask_all = (docs_low[:, None] == docs_high) & (
docs_high[:, None] == docs_low
)
blockmask_any = causal_blockmask_any & document_blockmask_any
blockmask_all = causal_blockmask_all & document_blockmask_all
partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(
blockmask_any & ~blockmask_all
)
full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
def build_bm(window_size_blocks: Tensor) -> BlockMask:
return BlockMask.from_kv_blocks(
torch.clamp_max(
partial_kv_num_blocks,
torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1),
),
partial_kv_indices,
torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
full_kv_indices,
BLOCK_SIZE=BLOCK_SIZE,
mask_mod=document_causal,
)
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
return build_bm(sliding_window_num_blocks), build_bm(
sliding_window_num_blocks // 2
)
def forward(
self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor
):
assert input_seq.ndim == 1
ve = [value_embed(input_seq) for value_embed in self.value_embeds]
# 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
ve = (
[ve[0], ve[1], ve[2]]
+ [None] * (len(self.blocks) - 6)
+ [ve[0], ve[1], ve[2]]
)
assert len(ve) == len(self.blocks)
long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
block_masks = [
long_bm,
short_bm,
short_bm,
short_bm,
long_bm,
short_bm,
short_bm,
long_bm,
short_bm,
short_bm,
short_bm,
long_bm,
]
assert len(block_masks) == len(self.blocks)
x = x0 = rms_norm(
self.embed(input_seq)[None]
) # use of norm here by @Grad62304977
# U-net design by @brendanh0gan
skip_connections = []
n = len(self.skip_weights)
for i in range(len(self.blocks)):
if i >= n:
x = x + self.skip_weights[i - n] * skip_connections.pop()
x = self.blocks[i](x, ve[i], x0, block_masks[i])
if i < n:
skip_connections.append(x)
x = rms_norm(x)
logits = self.lm_head(x).float()
# @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1) ** 0.5))
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
target_seq,
reduction="sum" if self.training else "mean",
)
return loss
# attention window size schedule: linearly increase
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(
non_blocking=True
)
def get_window_size_blocks(step: int, total_iterations: int):
x = step / total_iterations # progress in training
assert 0 <= x <= 1
# Linearly increase the block-wise sliding window size over training 128 -> 1792
# increase by @fernbear.bsky.social; block-wise by @YouJiacheng
window_size = next_multiple_of_n(1728 * x, n=128)
return get_window_size_blocks_helper(window_size)
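# e.g. step 0 -> 128-token window (1 block); the final step -> 1792 tokens (14 blocks)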
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Data Loader
# ----------------------------------------------------------------------------------------------------------------------------------------------------
def _load_data_shard(file: Path):
header = torch.from_file(
str(file), False, 256, dtype=torch.int32
) # header is 256 int32
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
num_tokens = int(header[2]) # number of tokens (claimed)
with file.open("rb", buffering=0) as f:
tokens = torch.empty(
num_tokens, dtype=torch.uint16, pin_memory=True
) # avoid pin_memory copy by @YouJiacheng
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
return tokens
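# Shard layout implied by the asserts above:
#   bytes [0, 1024):                     256 x int32 header = [20240520 (magic), 1 (version), num_tokens, ...]
#   bytes [1024, 1024 + 2 * num_tokens): uint16 GPT-2 token ids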
def distributed_data_generator(
filename_pattern: str, batch_size: int, rank: int, world_size: int
):
files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
assert batch_size % world_size == 0
local_batch_size = batch_size // world_size
file_iter = iter(
files
) # use itertools.cycle(files) instead if you want to do multi-epoch training
tokens, pos = _load_data_shard(next(file_iter)), 0
while True:
if pos + batch_size + 1 >= len(tokens):
tokens, pos = _load_data_shard(next(file_iter)), 0
buf = tokens[pos + rank * local_batch_size :][: local_batch_size + 1]
inputs = buf[:-1].to(
device="cuda", dtype=torch.int32, non_blocking=True
) # no sync on host side;
targets = buf[1:].to(
device="cuda", dtype=torch.int64, non_blocking=True
) # H2D in another stream isn't helpful.
pos += batch_size
yield inputs, targets
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Preamble
# ----------------------------------------------------------------------------------------------------------------------------------------------------
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size == 1
assert torch.cuda.is_available()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = rank == 0
logfile = None
if master_process:
run_id = uuid.uuid4()
os.makedirs("logs", exist_ok=True)
logfile = f"logs/{run_id}.txt"
print(logfile)
def print0(s, console=False):
if master_process:
with open(logfile, "a") as f:
if console:
print(s)
print(s, file=f)
print0(code)
print0("=" * 100)
# Environment information
print0(f"Running Python {sys.version}")
print0(
f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
)
print0(
subprocess.run(
["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
).stdout
)
print0("=" * 100)
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Construct model and optimizers
# ----------------------------------------------------------------------------------------------------------------------------------------------------
@dataclass
class Hyperparameters:
# data
train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
val_files = (
"data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
)
val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
train_seq_len = 48 * 1024 # FlexAttention sequence length
val_seq_len = 4 * 64 * 1024 # FlexAttention sequence length for validation
# optimization
num_iterations = 1770 # number of iterations to run
cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate
# architecture
vocab_size = 50257
# evaluation and logging
val_loss_every = (
125 # every how many steps to evaluate val loss? 0 for only at the end
)
gala_update_freq = (
150 # number of steps between GALA updates. Same for both optimizers.
)
save_checkpoint = False
args = Hyperparameters()
model: nn.Module = GPT(
vocab_size=args.vocab_size,
num_layers=12,
num_heads=6,
model_dim=768,
max_seq_len=max(args.train_seq_len, args.val_seq_len),
).to(device)
for m in model.modules():
if isinstance(m, nn.Embedding):
m.bfloat16()
model = DDP(model, device_ids=[device.index])
hidden_matrix_params = [
p
for n, p in model.module.blocks.named_parameters()
if p.ndim >= 2 and "embed" not in n
]
embed_params = [p for n, p in model.module.named_parameters() if "embed" in n]
scalar_params = [p for p in model.module.parameters() if p.ndim < 2]
head_params = [model.module.lm_head.weight]
adam_params = [
dict(params=head_params, lr=0.22),
dict(params=embed_params, lr=0.6),
dict(params=scalar_params, lr=0.04),
]
optimizer1 = GalaAdam(adam_params, betas=(0.8, 0.95), eps=1e-8, weight_decay=0.0)
optimizer2 = GalaMuon(hidden_matrix_params, lr=0.05, momentum=0.95)
optimizers = [optimizer1, optimizer2]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Train Loop
# ----------------------------------------------------------------------------------------------------------------------------------------------------
def default_train_step(
model: nn.Module,
optimizers: list[torch.optim.Optimizer],
inputs,
targets,
sliding_window_num_blocks: Tensor,
):
model(inputs, targets, sliding_window_num_blocks).backward()
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
compiled_default_train_step = torch.compile(default_train_step, dynamic=False)
def gala_train_step(
model: nn.Module,
optimizers: list[torch.optim.Optimizer],
inputs,
targets,
sliding_window_num_blocks: Tensor,
):
model(inputs, targets, sliding_window_num_blocks).backward()
for opt in optimizers:
opt.gala_phase1()
model.zero_grad(set_to_none=True)
model(inputs, targets, sliding_window_num_blocks).backward()
for opt in optimizers:
opt.gala_phase2()
model.zero_grad(set_to_none=True)
compiled_gala_train_step = torch.compile(gala_train_step, dynamic=False)
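# A GALA step costs two forward/backward passes on the same batch:
#   pass 1 -> gala_phase1: save g_t per parameter, then apply the base optimizer step;
#   pass 2 -> gala_phase2: recompute gradients at the new point (g_{t+1}) and
#             reset each group's lr from the alignment/curvature sums.
# Since group["lr"] is a Python float baked into the compiled graph (see the
# TODO in GalaAdam), changing it likely forces recompilation, which would be
# consistent with the much slower GALA steps in the log below.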
def val_step(model, inputs, targets, sliding_window_num_blocks: Tensor):
with torch.no_grad():
loss = model(inputs, targets, sliding_window_num_blocks)
return loss
compiled_val_step = torch.compile(val_step, dynamic=False)
# Warmup kernels.
initial_state = dict(
model=copy.deepcopy(model.state_dict()),
optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers],
)
for _ in range(5):
inputs = targets = torch.randint(
0, args.vocab_size, size=(args.train_seq_len,), device="cuda"
)
compiled_default_train_step(
model,
optimizers,
inputs,
targets,
get_window_size_blocks(0, args.num_iterations),
)
compiled_gala_train_step(
model,
optimizers,
inputs,
targets,
get_window_size_blocks(0, args.num_iterations),
)
compiled_val_step(
model,
inputs,
targets,
get_window_size_blocks(0, args.num_iterations),
)
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
opt.load_state_dict(opt_state)
del initial_state
train_loader = distributed_data_generator(
args.train_files, world_size * args.train_seq_len, rank, world_size
)
training_time_ms = 0
torch.cuda.synchronize()
t0 = time.perf_counter()
for step in range(args.num_iterations + 1):
last_step = step == args.num_iterations
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
torch.cuda.synchronize()
training_time_ms += 1000 * (time.perf_counter() - t0)
model.eval()
val_loader = distributed_data_generator(
args.val_files, world_size * args.val_seq_len, rank, world_size
)
val_loss = 0
with torch.no_grad():
for _ in range(args.val_tokens // (world_size * args.val_seq_len)):
inputs, targets = next(val_loader)
val_loss += compiled_val_step(
model,
inputs,
targets,
get_window_size_blocks(step, args.num_iterations),
)
val_loss /= args.val_tokens // (world_size * args.val_seq_len)
dist.all_reduce(val_loss, dist.ReduceOp.AVG)
del val_loader
print0(
f"step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms",
console=True,
)
model.train()
torch.cuda.synchronize()
t0 = time.perf_counter()
if last_step:
if master_process and args.save_checkpoint:
log = dict(
step=step,
code=code,
model=model.state_dict(),
optimizers=[opt.state_dict() for opt in optimizers],
)
os.makedirs(f"logs/{run_id}", exist_ok=True)
torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
break
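    # Muon momentum warmup: ramp optimizer2's momentum linearly from 0.85 to 0.95 over the first 300 steps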
for group in optimizer2.param_groups:
frac = min(step / 300, 1)
group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
inputs, targets = next(train_loader)
if (step + 1) % args.gala_update_freq == 0:
compiled_gala_train_step(
model,
optimizers,
inputs,
targets,
get_window_size_blocks(step, args.num_iterations),
)
else:
compiled_default_train_step(
model,
optimizers,
inputs,
targets,
get_window_size_blocks(step, args.num_iterations),
)
if (step + 1) % args.gala_update_freq == 0:
if master_process:
# For GalaAdam (optimizer1) which has multiple parameter groups
adam_lrs = [f"{g['lr']:.6f}" for g in optimizer1.param_groups]
# For GalaMuon (optimizer2) which has one parameter group
muon_lrs = [f"{g['lr']:.6f}" for g in optimizer2.param_groups]
log_msg = (
f"GALA update on step {step + 1}: "
f"GalaAdam LRs=[{', '.join(adam_lrs)}], "
f"GalaMuon LRs=[{', '.join(muon_lrs)}]"
)
print0(log_msg, console=True)
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
print0(
f"step:{step + 1}/{args.num_iterations} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms",
console=True,
)
# ----------------------------------------------------------------------------------------------------------------------------------------------------
# Epilogue
# ----------------------------------------------------------------------------------------------------------------------------------------------------
print0(
f"peak memory allocated: {torch.cuda.max_memory_allocated() // (1024**2)} MiB reserved: {torch.cuda.max_memory_reserved() // (1024**2)} MiB",
console=True,
)
dist.destroy_process_group()
====================================================================================================
Running Python 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0]
Running PyTorch 2.7.1+cu126 compiled for CUDA 12.6
Fri Jun 20 02:07:30 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:C6:00.0 Off | 0 |
| N/A 40C P0 141W / 700W | 1184MiB / 81559MiB | 4% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
====================================================================================================
step:0/1770 val_loss:10.8258 train_time:0ms step_avg:0.04ms
step:1/1770 train_time:22227ms step_avg:22226.87ms
step:2/1770 train_time:44293ms step_avg:22146.66ms
step:3/1770 train_time:44392ms step_avg:14797.18ms
step:4/1770 train_time:44493ms step_avg:11123.24ms
step:5/1770 train_time:44595ms step_avg:8919.06ms
step:6/1770 train_time:44697ms step_avg:7449.55ms
step:7/1770 train_time:44800ms step_avg:6399.98ms
step:8/1770 train_time:44902ms step_avg:5612.79ms
step:9/1770 train_time:45005ms step_avg:5000.52ms
step:10/1770 train_time:45107ms step_avg:4510.71ms
step:11/1770 train_time:45209ms step_avg:4109.95ms
step:12/1770 train_time:45312ms step_avg:3776.03ms
step:13/1770 train_time:45414ms step_avg:3493.41ms
step:14/1770 train_time:45516ms step_avg:3251.18ms
step:15/1770 train_time:45619ms step_avg:3041.28ms
step:16/1770 train_time:45721ms step_avg:2857.57ms
step:17/1770 train_time:45823ms step_avg:2695.49ms
step:18/1770 train_time:45925ms step_avg:2551.41ms
step:19/1770 train_time:46028ms step_avg:2422.51ms
step:20/1770 train_time:46130ms step_avg:2306.51ms
step:21/1770 train_time:46233ms step_avg:2201.56ms
step:22/1770 train_time:46335ms step_avg:2106.13ms
step:23/1770 train_time:46437ms step_avg:2019.00ms
step:24/1770 train_time:46540ms step_avg:1939.15ms
step:25/1770 train_time:46642ms step_avg:1865.68ms
step:26/1770 train_time:46744ms step_avg:1797.85ms
step:27/1770 train_time:46848ms step_avg:1735.09ms
step:28/1770 train_time:46950ms step_avg:1676.80ms
step:29/1770 train_time:47053ms step_avg:1622.53ms
step:30/1770 train_time:47156ms step_avg:1571.86ms
step:31/1770 train_time:47258ms step_avg:1524.45ms
step:32/1770 train_time:47360ms step_avg:1480.01ms
step:33/1770 train_time:47462ms step_avg:1438.25ms
step:34/1770 train_time:47565ms step_avg:1398.96ms
step:35/1770 train_time:47667ms step_avg:1361.91ms
step:36/1770 train_time:47769ms step_avg:1326.93ms
step:37/1770 train_time:47872ms step_avg:1293.84ms
step:38/1770 train_time:47975ms step_avg:1262.51ms
step:39/1770 train_time:48077ms step_avg:1232.74ms
step:40/1770 train_time:48180ms step_avg:1204.50ms
step:41/1770 train_time:48282ms step_avg:1177.60ms
step:42/1770 train_time:48384ms step_avg:1152.01ms
step:43/1770 train_time:48487ms step_avg:1127.61ms
step:44/1770 train_time:48590ms step_avg:1104.32ms
step:45/1770 train_time:48693ms step_avg:1082.06ms
step:46/1770 train_time:48795ms step_avg:1060.76ms
step:47/1770 train_time:48898ms step_avg:1040.38ms
step:48/1770 train_time:49001ms step_avg:1020.85ms
step:49/1770 train_time:49103ms step_avg:1002.11ms
step:50/1770 train_time:49206ms step_avg:984.12ms
step:51/1770 train_time:49309ms step_avg:966.85ms
step:52/1770 train_time:49412ms step_avg:950.22ms
step:53/1770 train_time:49514ms step_avg:934.22ms
step:54/1770 train_time:49617ms step_avg:918.83ms
step:55/1770 train_time:49719ms step_avg:903.99ms
step:56/1770 train_time:49822ms step_avg:889.67ms
step:57/1770 train_time:49926ms step_avg:875.90ms
step:58/1770 train_time:50026ms step_avg:862.52ms
step:59/1770 train_time:50129ms step_avg:849.64ms
step:60/1770 train_time:50231ms step_avg:837.18ms
step:61/1770 train_time:50333ms step_avg:825.14ms
step:62/1770 train_time:50436ms step_avg:813.49ms
step:63/1770 train_time:50539ms step_avg:802.20ms
step:64/1770 train_time:50641ms step_avg:791.27ms
step:65/1770 train_time:50744ms step_avg:780.68ms
step:66/1770 train_time:50846ms step_avg:770.40ms
step:67/1770 train_time:50949ms step_avg:760.43ms
step:68/1770 train_time:51052ms step_avg:750.76ms
step:69/1770 train_time:51157ms step_avg:741.41ms
step:70/1770 train_time:51257ms step_avg:732.24ms
step:71/1770 train_time:51359ms step_avg:723.37ms
step:72/1770 train_time:51461ms step_avg:714.74ms
step:73/1770 train_time:51564ms step_avg:706.35ms
step:74/1770 train_time:51667ms step_avg:698.20ms
step:75/1770 train_time:51769ms step_avg:690.26ms
step:76/1770 train_time:51872ms step_avg:682.53ms
step:77/1770 train_time:51975ms step_avg:675.00ms
step:78/1770 train_time:52078ms step_avg:667.66ms
step:79/1770 train_time:52181ms step_avg:660.52ms
step:80/1770 train_time:52284ms step_avg:653.55ms
step:81/1770 train_time:52388ms step_avg:646.76ms
step:82/1770 train_time:52491ms step_avg:640.13ms
step:83/1770 train_time:52594ms step_avg:633.66ms
step:84/1770 train_time:52696ms step_avg:627.33ms
step:85/1770 train_time:52799ms step_avg:621.16ms
step:86/1770 train_time:52901ms step_avg:615.13ms
step:87/1770 train_time:53003ms step_avg:609.23ms
step:88/1770 train_time:53106ms step_avg:603.48ms
step:89/1770 train_time:53209ms step_avg:597.86ms
step:90/1770 train_time:53312ms step_avg:592.36ms
step:91/1770 train_time:53415ms step_avg:586.98ms
step:92/1770 train_time:53518ms step_avg:581.71ms
step:93/1770 train_time:53621ms step_avg:576.56ms
step:94/1770 train_time:53724ms step_avg:571.53ms
step:95/1770 train_time:53827ms step_avg:566.60ms
step:96/1770 train_time:53930ms step_avg:561.77ms
step:97/1770 train_time:54033ms step_avg:557.04ms
step:98/1770 train_time:54135ms step_avg:552.40ms
step:99/1770 train_time:54237ms step_avg:547.85ms
step:100/1770 train_time:54339ms step_avg:543.39ms
step:101/1770 train_time:54441ms step_avg:539.02ms
step:102/1770 train_time:54544ms step_avg:534.74ms
step:103/1770 train_time:54647ms step_avg:530.55ms
step:104/1770 train_time:54750ms step_avg:526.44ms
step:105/1770 train_time:54853ms step_avg:522.40ms
step:106/1770 train_time:54955ms step_avg:518.45ms
step:107/1770 train_time:55058ms step_avg:514.56ms
step:108/1770 train_time:55161ms step_avg:510.75ms
step:109/1770 train_time:55263ms step_avg:507.00ms
step:110/1770 train_time:55366ms step_avg:503.33ms
step:111/1770 train_time:55469ms step_avg:499.72ms
step:112/1770 train_time:55572ms step_avg:496.18ms
step:113/1770 train_time:55675ms step_avg:492.70ms
step:114/1770 train_time:55777ms step_avg:489.28ms
step:115/1770 train_time:55881ms step_avg:485.92ms
step:116/1770 train_time:55983ms step_avg:482.61ms
step:117/1770 train_time:56086ms step_avg:479.37ms
step:118/1770 train_time:56189ms step_avg:476.18ms
step:119/1770 train_time:56292ms step_avg:473.04ms
step:120/1770 train_time:56399ms step_avg:469.99ms
step:121/1770 train_time:56498ms step_avg:466.93ms
step:122/1770 train_time:56601ms step_avg:463.94ms
step:123/1770 train_time:56703ms step_avg:461.00ms
step:124/1770 train_time:56806ms step_avg:458.11ms
step:125/1770 train_time:56909ms step_avg:455.27ms
step:125/1770 val_loss:5.6197 train_time:56913ms step_avg:455.30ms
step:126/1770 train_time:57015ms step_avg:452.50ms
step:127/1770 train_time:57117ms step_avg:449.74ms
step:128/1770 train_time:57221ms step_avg:447.04ms
step:129/1770 train_time:57323ms step_avg:444.36ms
step:130/1770 train_time:57426ms step_avg:441.74ms
step:131/1770 train_time:57529ms step_avg:439.15ms
step:132/1770 train_time:57631ms step_avg:436.60ms
step:133/1770 train_time:57735ms step_avg:434.10ms
step:134/1770 train_time:57838ms step_avg:431.63ms
step:135/1770 train_time:57942ms step_avg:429.20ms
step:136/1770 train_time:58045ms step_avg:426.80ms
step:137/1770 train_time:58149ms step_avg:424.44ms
step:138/1770 train_time:58253ms step_avg:422.12ms
step:139/1770 train_time:58357ms step_avg:419.83ms
step:140/1770 train_time:58460ms step_avg:417.57ms
step:141/1770 train_time:58564ms step_avg:415.34ms
step:142/1770 train_time:58667ms step_avg:413.15ms
step:143/1770 train_time:58770ms step_avg:410.98ms
step:144/1770 train_time:58874ms step_avg:408.85ms
step:145/1770 train_time:58978ms step_avg:406.74ms
step:146/1770 train_time:59081ms step_avg:404.66ms
step:147/1770 train_time:59184ms step_avg:402.61ms
step:148/1770 train_time:59287ms step_avg:400.59ms
step:149/1770 train_time:59391ms step_avg:398.59ms
GALA update on step 150: GalaAdam LRs=[0.206778, 18.967335, 0.000079], GalaMuon LRs=[0.001359]
step:150/1770 train_time:94291ms step_avg:628.61ms
step:151/1770 train_time:94422ms step_avg:625.31ms
step:152/1770 train_time:94525ms step_avg:621.88ms
step:153/1770 train_time:94632ms step_avg:618.51ms
step:154/1770 train_time:94738ms step_avg:615.18ms
step:155/1770 train_time:94845ms step_avg:611.90ms
step:156/1770 train_time:94952ms step_avg:608.66ms
step:157/1770 train_time:95058ms step_avg:605.47ms
step:158/1770 train_time:95165ms step_avg:602.31ms
step:159/1770 train_time:95272ms step_avg:599.20ms
step:160/1770 train_time:95379ms step_avg:596.12ms
step:161/1770 train_time:95487ms step_avg:593.08ms
step:162/1770 train_time:95594ms step_avg:590.08ms
step:163/1770 train_time:95700ms step_avg:587.12ms
step:164/1770 train_time:95807ms step_avg:584.19ms
step:165/1770 train_time:95914ms step_avg:581.30ms
step:166/1770 train_time:96021ms step_avg:578.44ms
step:167/1770 train_time:96128ms step_avg:575.62ms
step:168/1770 train_time:96235ms step_avg:572.83ms
step:169/1770 train_time:96342ms step_avg:570.07ms
step:170/1770 train_time:96449ms step_avg:567.35ms
step:171/1770 train_time:96556ms step_avg:564.65ms
step:172/1770 train_time:96663ms step_avg:561.99ms
step:173/1770 train_time:96771ms step_avg:559.37ms
step:174/1770 train_time:96879ms step_avg:556.77ms
step:175/1770 train_time:96986ms step_avg:554.20ms
step:176/1770 train_time:97093ms step_avg:551.67ms
step:177/1770 train_time:97200ms step_avg:549.15ms
step:178/1770 train_time:97307ms step_avg:546.67ms
step:179/1770 train_time:97414ms step_avg:544.21ms
step:180/1770 train_time:97521ms step_avg:541.78ms
step:181/1770 train_time:97628ms step_avg:539.38ms
step:182/1770 train_time:97735ms step_avg:537.01ms
step:183/1770 train_time:97842ms step_avg:534.66ms
step:184/1770 train_time:97949ms step_avg:532.33ms
step:185/1770 train_time:98056ms step_avg:530.03ms
step:186/1770 train_time:98163ms step_avg:527.76ms
step:187/1770 train_time:98270ms step_avg:525.51ms
step:188/1770 train_time:98377ms step_avg:523.28ms
step:189/1770 train_time:98484ms step_avg:521.08ms
step:190/1770 train_time:98591ms step_avg:518.90ms
step:191/1770 train_time:98699ms step_avg:516.75ms
step:192/1770 train_time:98806ms step_avg:514.61ms
step:193/1770 train_time:98913ms step_avg:512.50ms
step:194/1770 train_time:99020ms step_avg:510.41ms
step:195/1770 train_time:99127ms step_avg:508.34ms
step:196/1770 train_time:99233ms step_avg:506.29ms
step:197/1770 train_time:99340ms step_avg:504.26ms
step:198/1770 train_time:99447ms step_avg:502.26ms
step:199/1770 train_time:99554ms step_avg:500.27ms
step:200/1770 train_time:99667ms step_avg:498.34ms
step:201/1770 train_time:99775ms step_avg:496.39ms
step:202/1770 train_time:99881ms step_avg:494.46ms
step:203/1770 train_time:99987ms step_avg:492.55ms
step:204/1770 train_time:100095ms step_avg:490.66ms
step:205/1770 train_time:100202ms step_avg:488.79ms
step:206/1770 train_time:100309ms step_avg:486.94ms
step:207/1770 train_time:100416ms step_avg:485.10ms
step:208/1770 train_time:100523ms step_avg:483.28ms
step:209/1770 train_time:100630ms step_avg:481.48ms
step:210/1770 train_time:100743ms step_avg:479.73ms
step:211/1770 train_time:100847ms step_avg:477.95ms
step:212/1770 train_time:100954ms step_avg:476.20ms
step:213/1770 train_time:101065ms step_avg:474.48ms
step:214/1770 train_time:101169ms step_avg:472.75ms
step:215/1770 train_time:101277ms step_avg:471.06ms
step:216/1770 train_time:101384ms step_avg:469.37ms
step:217/1770 train_time:101491ms step_avg:467.70ms
step:218/1770 train_time:101598ms step_avg:466.05ms
step:219/1770 train_time:101705ms step_avg:464.41ms
step:220/1770 train_time:101812ms step_avg:462.78ms
step:221/1770 train_time:101920ms step_avg:461.17ms
step:222/1770 train_time:102027ms step_avg:459.58ms
step:223/1770 train_time:102134ms step_avg:458.00ms
step:224/1770 train_time:102240ms step_avg:456.43ms
step:225/1770 train_time:102348ms step_avg:454.88ms
step:226/1770 train_time:102455ms step_avg:453.34ms
step:227/1770 train_time:102562ms step_avg:451.81ms
step:228/1770 train_time:102669ms step_avg:450.30ms
step:229/1770 train_time:102777ms step_avg:448.81ms
step:230/1770 train_time:102884ms step_avg:447.32ms
step:231/1770 train_time:102992ms step_avg:445.85ms
step:232/1770 train_time:103099ms step_avg:444.39ms
step:233/1770 train_time:103206ms step_avg:442.95ms
step:234/1770 train_time:103313ms step_avg:441.51ms
step:235/1770 train_time:103420ms step_avg:440.09ms
step:236/1770 train_time:103528ms step_avg:438.68ms
step:237/1770 train_time:103635ms step_avg:437.28ms
step:238/1770 train_time:103742ms step_avg:435.89ms
step:239/1770 train_time:103849ms step_avg:434.51ms
step:240/1770 train_time:103956ms step_avg:433.15ms
step:241/1770 train_time:104064ms step_avg:431.80ms
step:242/1770 train_time:104170ms step_avg:430.46ms
step:243/1770 train_time:104277ms step_avg:429.13ms
step:244/1770 train_time:104385ms step_avg:427.81ms
step:245/1770 train_time:104492ms step_avg:426.50ms
step:246/1770 train_time:104599ms step_avg:425.20ms
step:247/1770 train_time:104708ms step_avg:423.92ms
step:248/1770 train_time:104816ms step_avg:422.65ms
step:249/1770 train_time:104920ms step_avg:421.37ms
step:250/1770 train_time:105027ms step_avg:420.11ms
step:250/1770 val_loss:5.5540 train_time:105031ms step_avg:420.12ms
step:251/1770 train_time:105138ms step_avg:418.88ms
step:252/1770 train_time:105246ms step_avg:417.64ms
step:253/1770 train_time:105353ms step_avg:416.41ms
step:254/1770 train_time:105460ms step_avg:415.20ms
step:255/1770 train_time:105568ms step_avg:413.99ms
step:256/1770 train_time:105677ms step_avg:412.80ms
step:257/1770 train_time:105786ms step_avg:411.62ms
step:258/1770 train_time:105891ms step_avg:410.43ms
step:259/1770 train_time:105998ms step_avg:409.26ms
step:260/1770 train_time:106105ms step_avg:408.10ms
step:261/1770 train_time:106212ms step_avg:406.94ms
step:262/1770 train_time:106319ms step_avg:405.80ms
step:263/1770 train_time:106426ms step_avg:404.66ms
step:264/1770 train_time:106534ms step_avg:403.54ms
step:265/1770 train_time:106642ms step_avg:402.42ms
step:266/1770 train_time:106749ms step_avg:401.31ms
step:267/1770 train_time:106858ms step_avg:400.22ms
step:268/1770 train_time:106965ms step_avg:399.12ms
step:269/1770 train_time:107073ms step_avg:398.04ms
step:270/1770 train_time:107181ms step_avg:396.97ms
step:271/1770 train_time:107289ms step_avg:395.90ms
step:272/1770 train_time:107397ms step_avg:394.84ms
step:273/1770 train_time:107504ms step_avg:393.79ms
step:274/1770 train_time:107612ms step_avg:392.74ms
step:275/1770 train_time:107720ms step_avg:391.71ms
step:276/1770 train_time:107827ms step_avg:390.68ms
step:277/1770 train_time:107935ms step_avg:389.66ms
step:278/1770 train_time:108043ms step_avg:388.64ms
step:279/1770 train_time:108151ms step_avg:387.64ms
step:280/1770 train_time:108259ms step_avg:386.64ms
step:281/1770 train_time:108367ms step_avg:385.65ms
step:282/1770 train_time:108475ms step_avg:384.66ms
step:283/1770 train_time:108583ms step_avg:383.69ms
step:284/1770 train_time:108690ms step_avg:382.71ms
step:285/1770 train_time:108798ms step_avg:381.75ms
step:286/1770 train_time:108905ms step_avg:380.79ms
step:287/1770 train_time:109013ms step_avg:379.84ms
step:288/1770 train_time:109124ms step_avg:378.90ms
step:289/1770 train_time:109228ms step_avg:377.95ms
step:290/1770 train_time:109336ms step_avg:377.02ms
step:291/1770 train_time:109443ms step_avg:376.09ms
step:292/1770 train_time:109551ms step_avg:375.17ms
step:293/1770 train_time:109659ms step_avg:374.26ms
step:294/1770 train_time:109767ms step_avg:373.36ms
step:295/1770 train_time:109874ms step_avg:372.46ms
step:296/1770 train_time:109982ms step_avg:371.56ms
step:297/1770 train_time:110093ms step_avg:370.68ms
step:298/1770 train_time:110198ms step_avg:369.79ms
step:299/1770 train_time:110306ms step_avg:368.92ms
GALA update on step 300: GalaAdam LRs=[0.324976, 18.653418, 0.000159], GalaMuon LRs=[0.000899]
step:300/1770 train_time:146548ms step_avg:488.49ms
step:301/1770 train_time:146656ms step_avg:487.23ms
step:302/1770 train_time:146761ms step_avg:485.96ms
step:303/1770 train_time:146869ms step_avg:484.71ms
step:304/1770 train_time:146976ms step_avg:483.47ms
step:305/1770 train_time:147082ms step_avg:482.24ms
step:306/1770 train_time:147190ms step_avg:481.01ms
step:307/1770 train_time:147297ms step_avg:479.80ms
step:308/1770 train_time:147405ms step_avg:478.59ms
step:309/1770 train_time:147511ms step_avg:477.38ms
step:310/1770 train_time:147619ms step_avg:476.19ms
step:311/1770 train_time:147726ms step_avg:475.00ms
step:312/1770 train_time:147834ms step_avg:473.83ms
step:313/1770 train_time:147941ms step_avg:472.65ms
step:314/1770 train_time:148048ms step_avg:471.49ms
step:315/1770 train_time:148156ms step_avg:470.34ms
step:316/1770 train_time:148263ms step_avg:469.19ms
step:317/1770 train_time:148371ms step_avg:468.05ms
step:318/1770 train_time:148478ms step_avg:466.91ms
step:319/1770 train_time:148586ms step_avg:465.79ms
step:320/1770 train_time:148693ms step_avg:464.67ms
step:321/1770 train_time:148800ms step_avg:463.55ms
step:322/1770 train_time:148908ms step_avg:462.45ms
step:323/1770 train_time:149015ms step_avg:461.35ms
step:324/1770 train_time:149122ms step_avg:460.25ms
step:325/1770 train_time:149230ms step_avg:459.17ms
step:326/1770 train_time:149337ms step_avg:458.09ms
step:327/1770 train_time:149445ms step_avg:457.02ms
step:328/1770 train_time:149552ms step_avg:455.95ms
step:329/1770 train_time:149660ms step_avg:454.89ms
step:330/1770 train_time:149767ms step_avg:453.84ms
step:331/1770 train_time:149875ms step_avg:452.80ms
step:332/1770 train_time:149982ms step_avg:451.75ms
step:333/1770 train_time:150093ms step_avg:450.73ms
step:334/1770 train_time:150197ms step_avg:449.69ms
step:335/1770 train_time:150305ms step_avg:448.67ms
step:336/1770 train_time:150412ms step_avg:447.66ms
step:337/1770 train_time:150520ms step_avg:446.65ms
step:338/1770 train_time:150627ms step_avg:445.64ms
step:339/1770 train_time:150735ms step_avg:444.64ms
step:340/1770 train_time:150842ms step_avg:443.65ms
step:341/1770 train_time:150955ms step_avg:442.68ms
step:342/1770 train_time:151060ms step_avg:441.70ms
step:343/1770 train_time:151168ms step_avg:440.72ms
step:344/1770 train_time:151275ms step_avg:439.75ms
step:345/1770 train_time:151385ms step_avg:438.80ms
step:346/1770 train_time:151490ms step_avg:437.83ms
step:347/1770 train_time:151598ms step_avg:436.88ms
step:348/1770 train_time:151704ms step_avg:435.93ms
step:349/1770 train_time:151812ms step_avg:434.99ms
step:350/1770 train_time:151919ms step_avg:434.05ms
step:351/1770 train_time:152027ms step_avg:433.13ms
step:352/1770 train_time:152135ms step_avg:432.20ms
step:353/1770 train_time:152242ms step_avg:431.28ms
step:354/1770 train_time:152350ms step_avg:430.37ms
step:355/1770 train_time:152457ms step_avg:429.46ms
step:356/1770 train_time:152567ms step_avg:428.56ms
step:357/1770 train_time:152673ms step_avg:427.65ms
step:358/1770 train_time:152780ms step_avg:426.76ms
step:359/1770 train_time:152888ms step_avg:425.87ms
step:360/1770 train_time:152996ms step_avg:424.99ms
step:361/1770 train_time:153103ms step_avg:424.11ms
step:362/1770 train_time:153210ms step_avg:423.23ms
step:363/1770 train_time:153318ms step_avg:422.36ms
step:364/1770 train_time:153425ms step_avg:421.50ms
step:365/1770 train_time:153533ms step_avg:420.64ms
step:366/1770 train_time:153640ms step_avg:419.78ms
step:367/1770 train_time:153748ms step_avg:418.93ms
step:368/1770 train_time:153856ms step_avg:418.09ms
step:369/1770 train_time:153963ms step_avg:417.25ms
step:370/1770 train_time:154071ms step_avg:416.41ms
step:371/1770 train_time:154180ms step_avg:415.58ms
step:372/1770 train_time:154288ms step_avg:414.75ms
step:373/1770 train_time:154396ms step_avg:413.93ms
step:374/1770 train_time:154504ms step_avg:413.11ms
step:375/1770 train_time:154610ms step_avg:412.29ms
step:375/1770 val_loss:5.4144 train_time:154612ms step_avg:412.30ms
step:376/1770 train_time:154719ms step_avg:411.49ms
step:377/1770 train_time:154827ms step_avg:410.68ms
step:378/1770 train_time:154935ms step_avg:409.88ms
step:379/1770 train_time:155043ms step_avg:409.08ms
step:380/1770 train_time:155150ms step_avg:408.29ms
step:381/1770 train_time:155258ms step_avg:407.50ms
step:382/1770 train_time:155366ms step_avg:406.72ms
step:383/1770 train_time:155473ms step_avg:405.94ms
step:384/1770 train_time:155581ms step_avg:405.16ms
step:385/1770 train_time:155689ms step_avg:404.39ms
step:386/1770 train_time:155796ms step_avg:403.62ms
step:387/1770 train_time:155904ms step_avg:402.85ms
step:388/1770 train_time:156012ms step_avg:402.09ms
step:389/1770 train_time:156120ms step_avg:401.34ms
step:390/1770 train_time:156227ms step_avg:400.58ms
step:391/1770 train_time:156334ms step_avg:399.83ms
step:392/1770 train_time:156444ms step_avg:399.09ms
step:393/1770 train_time:156552ms step_avg:398.35ms
step:394/1770 train_time:156657ms step_avg:397.61ms
step:395/1770 train_time:156766ms step_avg:396.88ms
step:396/1770 train_time:156877ms step_avg:396.15ms
step:397/1770 train_time:156985ms step_avg:395.43ms
step:398/1770 train_time:157095ms step_avg:394.71ms
step:399/1770 train_time:157205ms step_avg:394.00ms
step:400/1770 train_time:157315ms step_avg:393.29ms
step:401/1770 train_time:157425ms step_avg:392.58ms
step:402/1770 train_time:157535ms step_avg:391.88ms
step:403/1770 train_time:157645ms step_avg:391.18ms
step:404/1770 train_time:157756ms step_avg:390.49ms
step:405/1770 train_time:157865ms step_avg:389.79ms
step:406/1770 train_time:157977ms step_avg:389.11ms
step:407/1770 train_time:158085ms step_avg:388.42ms
step:408/1770 train_time:158195ms step_avg:387.73ms
step:409/1770 train_time:158305ms step_avg:387.05ms
step:410/1770 train_time:158415ms step_avg:386.38ms
step:411/1770 train_time:158525ms step_avg:385.70ms
step:412/1770 train_time:158635ms step_avg:385.04ms
step:413/1770 train_time:158745ms step_avg:384.37ms
step:414/1770 train_time:158855ms step_avg:383.71ms
step:415/1770 train_time:158964ms step_avg:383.05ms
step:416/1770 train_time:159074ms step_avg:382.39ms
step:417/1770 train_time:159183ms step_avg:381.73ms
step:418/1770 train_time:159293ms step_avg:381.08ms
step:419/1770 train_time:159403ms step_avg:380.44ms
step:420/1770 train_time:159513ms step_avg:379.79ms
step:421/1770 train_time:159623ms step_avg:379.15ms
step:422/1770 train_time:159733ms step_avg:378.51ms
step:423/1770 train_time:159843ms step_avg:377.88ms
step:424/1770 train_time:159954ms step_avg:377.25ms
step:425/1770 train_time:160063ms step_avg:376.62ms
step:426/1770 train_time:160173ms step_avg:375.99ms
step:427/1770 train_time:160284ms step_avg:375.37ms
step:428/1770 train_time:160394ms step_avg:374.75ms
step:429/1770 train_time:160503ms step_avg:374.13ms
step:430/1770 train_time:160614ms step_avg:373.52ms
step:431/1770 train_time:160725ms step_avg:372.91ms
step:432/1770 train_time:160837ms step_avg:372.31ms
step:433/1770 train_time:160947ms step_avg:371.70ms
step:434/1770 train_time:161054ms step_avg:371.09ms
step:435/1770 train_time:161164ms step_avg:370.49ms
step:436/1770 train_time:161274ms step_avg:369.89ms
step:437/1770 train_time:161384ms step_avg:369.30ms
step:438/1770 train_time:161495ms step_avg:368.71ms
step:439/1770 train_time:161604ms step_avg:368.12ms
step:440/1770 train_time:161715ms step_avg:367.53ms
step:441/1770 train_time:161824ms step_avg:366.95ms
step:442/1770 train_time:161934ms step_avg:366.37ms
step:443/1770 train_time:162046ms step_avg:365.79ms
step:444/1770 train_time:162155ms step_avg:365.21ms
step:445/1770 train_time:162264ms step_avg:364.64ms
step:446/1770 train_time:162375ms step_avg:364.07ms
step:447/1770 train_time:162485ms step_avg:363.50ms
step:448/1770 train_time:162595ms step_avg:362.93ms
step:449/1770 train_time:162706ms step_avg:362.38ms
GALA update on step 450: GalaAdam LRs=[0.342739, 18.050045, 0.000155], GalaMuon LRs=[0.000679]
step:450/1770 train_time:204924ms step_avg:455.39ms
step:451/1770 train_time:205030ms step_avg:454.61ms
step:452/1770 train_time:205143ms step_avg:453.86ms
step:453/1770 train_time:205250ms step_avg:453.09ms
step:454/1770 train_time:205359ms step_avg:452.33ms
step:455/1770 train_time:205469ms step_avg:451.58ms
step:456/1770 train_time:205578ms step_avg:450.83ms
step:457/1770 train_time:205688ms step_avg:450.08ms
step:458/1770 train_time:205797ms step_avg:449.34ms
step:459/1770 train_time:205907ms step_avg:448.60ms
step:460/1770 train_time:206016ms step_avg:447.86ms
step:461/1770 train_time:206126ms step_avg:447.13ms
step:462/1770 train_time:206236ms step_avg:446.40ms
step:463/1770 train_time:206346ms step_avg:445.67ms
step:464/1770 train_time:206456ms step_avg:444.95ms
step:465/1770 train_time:206565ms step_avg:444.23ms
step:466/1770 train_time:206677ms step_avg:443.51ms
step:467/1770 train_time:206785ms step_avg:442.79ms
step:468/1770 train_time:206894ms step_avg:442.08ms
step:469/1770 train_time:207004ms step_avg:441.37ms
step:470/1770 train_time:207114ms step_avg:440.67ms
step:471/1770 train_time:207224ms step_avg:439.97ms
step:472/1770 train_time:207334ms step_avg:439.27ms
step:473/1770 train_time:207444ms step_avg:438.57ms
step:474/1770 train_time:207554ms step_avg:437.88ms
step:475/1770 train_time:207663ms step_avg:437.18ms
step:476/1770 train_time:207772ms step_avg:436.50ms
step:477/1770 train_time:207882ms step_avg:435.81ms
step:478/1770 train_time:207992ms step_avg:435.13ms
step:479/1770 train_time:208102ms step_avg:434.45ms
step:480/1770 train_time:208212ms step_avg:433.78ms
step:481/1770 train_time:208322ms step_avg:433.10ms
step:482/1770 train_time:208432ms step_avg:432.43ms
step:483/1770 train_time:208542ms step_avg:431.76ms
step:484/1770 train_time:208652ms step_avg:431.10ms
step:485/1770 train_time:208762ms step_avg:430.44ms
step:486/1770 train_time:208872ms step_avg:429.78ms
step:487/1770 train_time:208981ms step_avg:429.12ms
step:488/1770 train_time:209091ms step_avg:428.47ms
step:489/1770 train_time:209201ms step_avg:427.81ms
step:490/1770 train_time:209312ms step_avg:427.17ms
step:491/1770 train_time:209424ms step_avg:426.53ms
step:492/1770 train_time:209532ms step_avg:425.88ms
step:493/1770 train_time:209642ms step_avg:425.24ms
step:494/1770 train_time:209752ms step_avg:424.60ms
step:495/1770 train_time:209861ms step_avg:423.96ms
step:496/1770 train_time:209970ms step_avg:423.33ms
step:497/1770 train_time:210079ms step_avg:422.69ms
step:498/1770 train_time:210189ms step_avg:422.07ms
step:499/1770 train_time:210299ms step_avg:421.44ms
step:500/1770 train_time:210409ms step_avg:420.82ms
step:500/1770 val_loss:5.2232 train_time:210412ms step_avg:420.82ms
step:501/1770 train_time:210522ms step_avg:420.20ms
step:502/1770 train_time:210633ms step_avg:419.59ms
step:503/1770 train_time:210747ms step_avg:418.98ms
step:504/1770 train_time:210855ms step_avg:418.36ms
step:505/1770 train_time:210965ms step_avg:417.75ms
step:506/1770 train_time:211076ms step_avg:417.15ms
step:507/1770 train_time:211185ms step_avg:416.54ms
step:508/1770 train_time:211296ms step_avg:415.94ms
step:509/1770 train_time:211408ms step_avg:415.34ms
step:510/1770 train_time:211518ms step_avg:414.74ms
step:511/1770 train_time:211629ms step_avg:414.15ms
step:512/1770 train_time:211740ms step_avg:413.55ms
step:513/1770 train_time:211853ms step_avg:412.97ms
step:514/1770 train_time:211964ms step_avg:412.38ms
step:515/1770 train_time:212075ms step_avg:411.80ms
step:516/1770 train_time:212189ms step_avg:411.22ms
step:517/1770 train_time:212303ms step_avg:410.64ms
step:518/1770 train_time:212413ms step_avg:410.06ms
step:519/1770 train_time:212529ms step_avg:409.50ms
step:520/1770 train_time:212638ms step_avg:408.92ms
step:521/1770 train_time:212749ms step_avg:408.35ms
step:522/1770 train_time:212859ms step_avg:407.78ms
step:523/1770 train_time:212970ms step_avg:407.21ms
step:524/1770 train_time:213084ms step_avg:406.65ms
step:525/1770 train_time:213193ms step_avg:406.08ms
step:526/1770 train_time:213303ms step_avg:405.52ms
step:527/1770 train_time:213416ms step_avg:404.96ms
step:528/1770 train_time:213527ms step_avg:404.41ms
step:529/1770 train_time:213637ms step_avg:403.85ms
step:530/1770 train_time:213749ms step_avg:403.30ms
step:531/1770 train_time:213860ms step_avg:402.75ms
step:532/1770 train_time:213970ms step_avg:402.20ms
step:533/1770 train_time:214080ms step_avg:401.65ms
step:534/1770 train_time:214191ms step_avg:401.11ms
step:535/1770 train_time:214302ms step_avg:400.56ms
step:536/1770 train_time:214413ms step_avg:400.02ms
step:537/1770 train_time:214524ms step_avg:399.49ms
step:538/1770 train_time:214635ms step_avg:398.95ms
step:539/1770 train_time:214746ms step_avg:398.42ms
step:540/1770 train_time:214858ms step_avg:397.88ms
step:541/1770 train_time:214969ms step_avg:397.36ms
step:542/1770 train_time:215083ms step_avg:396.83ms
step:543/1770 train_time:215192ms step_avg:396.30ms
step:544/1770 train_time:215302ms step_avg:395.78ms
step:545/1770 train_time:215415ms step_avg:395.26ms
step:546/1770 train_time:215528ms step_avg:394.74ms
step:547/1770 train_time:215641ms step_avg:394.23ms
step:548/1770 train_time:215751ms step_avg:393.71ms
step:549/1770 train_time:215861ms step_avg:393.19ms
step:550/1770 train_time:215971ms step_avg:392.68ms
step:551/1770 train_time:216083ms step_avg:392.17ms
step:552/1770 train_time:216194ms step_avg:391.66ms
step:553/1770 train_time:216305ms step_avg:391.15ms
step:554/1770 train_time:216416ms step_avg:390.64ms
step:555/1770 train_time:216527ms step_avg:390.14ms
step:556/1770 train_time:216638ms step_avg:389.64ms
step:557/1770 train_time:216748ms step_avg:389.13ms
step:558/1770 train_time:216858ms step_avg:388.64ms
step:559/1770 train_time:216969ms step_avg:388.14ms
step:560/1770 train_time:217080ms step_avg:387.64ms
step:561/1770 train_time:217192ms step_avg:387.15ms
step:562/1770 train_time:217303ms step_avg:386.66ms
step:563/1770 train_time:217414ms step_avg:386.17ms
step:564/1770 train_time:217525ms step_avg:385.68ms
step:565/1770 train_time:217637ms step_avg:385.20ms
step:566/1770 train_time:217748ms step_avg:384.71ms
step:567/1770 train_time:217859ms step_avg:384.23ms
step:568/1770 train_time:217970ms step_avg:383.75ms
step:569/1770 train_time:218080ms step_avg:383.27ms
step:570/1770 train_time:218191ms step_avg:382.79ms
step:571/1770 train_time:218302ms step_avg:382.32ms
step:572/1770 train_time:218413ms step_avg:381.84ms
step:573/1770 train_time:218524ms step_avg:381.37ms
step:574/1770 train_time:218635ms step_avg:380.90ms
step:575/1770 train_time:218746ms step_avg:380.43ms
step:576/1770 train_time:218856ms step_avg:379.96ms
step:577/1770 train_time:218967ms step_avg:379.49ms
step:578/1770 train_time:219079ms step_avg:379.03ms
step:579/1770 train_time:219190ms step_avg:378.57ms
step:580/1770 train_time:219301ms step_avg:378.11ms
step:581/1770 train_time:219412ms step_avg:377.65ms
step:582/1770 train_time:219523ms step_avg:377.19ms
step:583/1770 train_time:219634ms step_avg:376.73ms
step:584/1770 train_time:219745ms step_avg:376.28ms
step:585/1770 train_time:219856ms step_avg:375.82ms
step:586/1770 train_time:219967ms step_avg:375.37ms
step:587/1770 train_time:220079ms step_avg:374.92ms
step:588/1770 train_time:220190ms step_avg:374.47ms
step:589/1770 train_time:220301ms step_avg:374.03ms
step:590/1770 train_time:220413ms step_avg:373.58ms
step:591/1770 train_time:220524ms step_avg:373.14ms
step:592/1770 train_time:220635ms step_avg:372.69ms
step:593/1770 train_time:220745ms step_avg:372.25ms
step:594/1770 train_time:220857ms step_avg:371.81ms
step:595/1770 train_time:220967ms step_avg:371.37ms
step:596/1770 train_time:221083ms step_avg:370.94ms
step:597/1770 train_time:221192ms step_avg:370.51ms
step:598/1770 train_time:221302ms step_avg:370.07ms
step:599/1770 train_time:221413ms step_avg:369.64ms
GALA update on step 600: GalaAdam LRs=[0.318861, 17.494511, 0.000163], GalaMuon LRs=[0.000431]
step:600/1770 train_time:221655ms step_avg:369.43ms
step:601/1770 train_time:221763ms step_avg:368.99ms
step:602/1770 train_time:221874ms step_avg:368.56ms
step:603/1770 train_time:221985ms step_avg:368.13ms
step:604/1770 train_time:222096ms step_avg:367.71ms
step:605/1770 train_time:222207ms step_avg:367.28ms
step:606/1770 train_time:222319ms step_avg:366.86ms
step:607/1770 train_time:222431ms step_avg:366.44ms
step:608/1770 train_time:222541ms step_avg:366.02ms
step:609/1770 train_time:222653ms step_avg:365.60ms
step:610/1770 train_time:222764ms step_avg:365.19ms
step:611/1770 train_time:222875ms step_avg:364.77ms
step:612/1770 train_time:222985ms step_avg:364.36ms
step:613/1770 train_time:223096ms step_avg:363.94ms
step:614/1770 train_time:223206ms step_avg:363.53ms
step:615/1770 train_time:223317ms step_avg:363.12ms
step:616/1770 train_time:223429ms step_avg:362.71ms
step:617/1770 train_time:223540ms step_avg:362.30ms
step:618/1770 train_time:223651ms step_avg:361.89ms
step:619/1770 train_time:223761ms step_avg:361.49ms
step:620/1770 train_time:223872ms step_avg:361.08ms
step:621/1770 train_time:223984ms step_avg:360.68ms
step:622/1770 train_time:224097ms step_avg:360.29ms
step:623/1770 train_time:224207ms step_avg:359.88ms
step:624/1770 train_time:224318ms step_avg:359.48ms
step:625/1770 train_time:224429ms step_avg:359.09ms
step:625/1770 val_loss:5.1036 train_time:224432ms step_avg:359.09ms
step:626/1770 train_time:224543ms step_avg:358.69ms
step:627/1770 train_time:224656ms step_avg:358.30ms
step:628/1770 train_time:224765ms step_avg:357.91ms
step:629/1770 train_time:224876ms step_avg:357.51ms
step:630/1770 train_time:224987ms step_avg:357.12ms
step:631/1770 train_time:225098ms step_avg:356.73ms
step:632/1770 train_time:225209ms step_avg:356.34ms
step:633/1770 train_time:225320ms step_avg:355.96ms
step:634/1770 train_time:225431ms step_avg:355.57ms
step:635/1770 train_time:225542ms step_avg:355.18ms
step:636/1770 train_time:225658ms step_avg:354.81ms
step:637/1770 train_time:225768ms step_avg:354.42ms
step:638/1770 train_time:225877ms step_avg:354.04ms
step:639/1770 train_time:225988ms step_avg:353.66ms
step:640/1770 train_time:226099ms step_avg:353.28ms
step:641/1770 train_time:226210ms step_avg:352.90ms
step:642/1770 train_time:226321ms step_avg:352.53ms
step:643/1770 train_time:226432ms step_avg:352.15ms
step:644/1770 train_time:226544ms step_avg:351.78ms
step:645/1770 train_time:226654ms step_avg:351.40ms
step:646/1770 train_time:226764ms step_avg:351.03ms
step:647/1770 train_time:226875ms step_avg:350.66ms
step:648/1770 train_time:226986ms step_avg:350.29ms
step:649/1770 train_time:227097ms step_avg:349.92ms
step:650/1770 train_time:227208ms step_avg:349.55ms
step:651/1770 train_time:227319ms step_avg:349.18ms
step:652/1770 train_time:227430ms step_avg:348.82ms
step:653/1770 train_time:227541ms step_avg:348.45ms
step:654/1770 train_time:227651ms step_avg:348.09ms
step:655/1770 train_time:227762ms step_avg:347.73ms
step:656/1770 train_time:227873ms step_avg:347.37ms
step:657/1770 train_time:227986ms step_avg:347.01ms
step:658/1770 train_time:228099ms step_avg:346.66ms
step:659/1770 train_time:228215ms step_avg:346.30ms
step:660/1770 train_time:228325ms step_avg:345.95ms
step:661/1770 train_time:228436ms step_avg:345.59ms
step:662/1770 train_time:228550ms step_avg:345.24ms
step:663/1770 train_time:228661ms step_avg:344.89ms
step:664/1770 train_time:228773ms step_avg:344.54ms
step:665/1770 train_time:228885ms step_avg:344.19ms
step:666/1770 train_time:228999ms step_avg:343.84ms
step:667/1770 train_time:229113ms step_avg:343.50ms
step:668/1770 train_time:229223ms step_avg:343.15ms
step:669/1770 train_time:229336ms step_avg:342.80ms
step:670/1770 train_time:229448ms step_avg:342.46ms
step:671/1770 train_time:229561ms step_avg:342.12ms
step:672/1770 train_time:229674ms step_avg:341.78ms
step:673/1770 train_time:229787ms step_avg:341.44ms
step:674/1770 train_time:229900ms step_avg:341.10ms
step:675/1770 train_time:230015ms step_avg:340.76ms
step:676/1770 train_time:230125ms step_avg:340.42ms
step:677/1770 train_time:230238ms step_avg:340.09ms
step:678/1770 train_time:230350ms step_avg:339.75ms
step:679/1770 train_time:230463ms step_avg:339.42ms
step:680/1770 train_time:230576ms step_avg:339.08ms
step:681/1770 train_time:230689ms step_avg:338.75ms
step:682/1770 train_time:230802ms step_avg:338.42ms
step:683/1770 train_time:230914ms step_avg:338.09ms
step:684/1770 train_time:231027ms step_avg:337.76ms
step:685/1770 train_time:231140ms step_avg:337.43ms
step:686/1770 train_time:231253ms step_avg:337.10ms
step:687/1770 train_time:231366ms step_avg:336.78ms
step:688/1770 train_time:231479ms step_avg:336.45ms
step:689/1770 train_time:231591ms step_avg:336.13ms
step:690/1770 train_time:231704ms step_avg:335.80ms
step:691/1770 train_time:231817ms step_avg:335.48ms
step:692/1770 train_time:231929ms step_avg:335.16ms
step:693/1770 train_time:232042ms step_avg:334.84ms
step:694/1770 train_time:232154ms step_avg:334.52ms
step:695/1770 train_time:232267ms step_avg:334.20ms
step:696/1770 train_time:232380ms step_avg:333.88ms
step:697/1770 train_time:232493ms step_avg:333.56ms
step:698/1770 train_time:232606ms step_avg:333.25ms
step:699/1770 train_time:232719ms step_avg:332.93ms
step:700/1770 train_time:232831ms step_avg:332.62ms
step:701/1770 train_time:232944ms step_avg:332.30ms
step:702/1770 train_time:233057ms step_avg:331.99ms
step:703/1770 train_time:233169ms step_avg:331.68ms
step:704/1770 train_time:233282ms step_avg:331.37ms
step:705/1770 train_time:233394ms step_avg:331.05ms
step:706/1770 train_time:233507ms step_avg:330.75ms
step:707/1770 train_time:233620ms step_avg:330.44ms
step:708/1770 train_time:233731ms step_avg:330.13ms
step:709/1770 train_time:233845ms step_avg:329.82ms
step:710/1770 train_time:233958ms step_avg:329.52ms
step:711/1770 train_time:234070ms step_avg:329.21ms
step:712/1770 train_time:234183ms step_avg:328.91ms
step:713/1770 train_time:234295ms step_avg:328.61ms
step:714/1770 train_time:234409ms step_avg:328.30ms
step:715/1770 train_time:234522ms step_avg:328.00ms
step:716/1770 train_time:234634ms step_avg:327.70ms
step:717/1770 train_time:234747ms step_avg:327.40ms
step:718/1770 train_time:234860ms step_avg:327.10ms
step:719/1770 train_time:234972ms step_avg:326.80ms
step:720/1770 train_time:235084ms step_avg:326.51ms
step:721/1770 train_time:235196ms step_avg:326.21ms
step:722/1770 train_time:235310ms step_avg:325.91ms
step:723/1770 train_time:235421ms step_avg:325.62ms
step:724/1770 train_time:235533ms step_avg:325.32ms
step:725/1770 train_time:235646ms step_avg:325.03ms
step:726/1770 train_time:235762ms step_avg:324.74ms
step:727/1770 train_time:235873ms step_avg:324.45ms
step:728/1770 train_time:235985ms step_avg:324.16ms
step:729/1770 train_time:236098ms step_avg:323.87ms
step:730/1770 train_time:236210ms step_avg:323.58ms
step:731/1770 train_time:236323ms step_avg:323.29ms
step:732/1770 train_time:236435ms step_avg:323.00ms
step:733/1770 train_time:236547ms step_avg:322.71ms
step:734/1770 train_time:236660ms step_avg:322.42ms
step:735/1770 train_time:236772ms step_avg:322.14ms
step:736/1770 train_time:236885ms step_avg:321.85ms
step:737/1770 train_time:236997ms step_avg:321.57ms
step:738/1770 train_time:237110ms step_avg:321.29ms
step:739/1770 train_time:237222ms step_avg:321.00ms
step:740/1770 train_time:237334ms step_avg:320.72ms
step:741/1770 train_time:237447ms step_avg:320.44ms
step:742/1770 train_time:237559ms step_avg:320.16ms
step:743/1770 train_time:237671ms step_avg:319.88ms
step:744/1770 train_time:237785ms step_avg:319.60ms
step:745/1770 train_time:237897ms step_avg:319.32ms
step:746/1770 train_time:238009ms step_avg:319.05ms
step:747/1770 train_time:238122ms step_avg:318.77ms
step:748/1770 train_time:238234ms step_avg:318.49ms
step:749/1770 train_time:238347ms step_avg:318.22ms
GALA update on step 750: GalaAdam LRs=[0.320886, 16.920443, 0.000166], GalaMuon LRs=[0.000409]
step:750/1770 train_time:238591ms step_avg:318.12ms
step:750/1770 val_loss:5.0297 train_time:238591ms step_avg:318.12ms
step:751/1770 train_time:238710ms step_avg:317.86ms
step:752/1770 train_time:238824ms step_avg:317.58ms
step:753/1770 train_time:238934ms step_avg:317.31ms
step:754/1770 train_time:239047ms step_avg:317.04ms
step:755/1770 train_time:239159ms step_avg:316.77ms
step:756/1770 train_time:239272ms step_avg:316.50ms
step:757/1770 train_time:239385ms step_avg:316.23ms
step:758/1770 train_time:239498ms step_avg:315.96ms
step:759/1770 train_time:239610ms step_avg:315.69ms
step:760/1770 train_time:239723ms step_avg:315.43ms
step:761/1770 train_time:239835ms step_avg:315.16ms
step:762/1770 train_time:239948ms step_avg:314.89ms
step:763/1770 train_time:240060ms step_avg:314.63ms
step:764/1770 train_time:240174ms step_avg:314.36ms
step:765/1770 train_time:240286ms step_avg:314.10ms
step:766/1770 train_time:240399ms step_avg:313.84ms
step:767/1770 train_time:240511ms step_avg:313.57ms
step:768/1770 train_time:240623ms step_avg:313.31ms
step:769/1770 train_time:240735ms step_avg:313.05ms
step:770/1770 train_time:240847ms step_avg:312.79ms
step:771/1770 train_time:240960ms step_avg:312.53ms
step:772/1770 train_time:241072ms step_avg:312.27ms
step:773/1770 train_time:241185ms step_avg:312.01ms
step:774/1770 train_time:241298ms step_avg:311.75ms
step:775/1770 train_time:241411ms step_avg:311.50ms
step:776/1770 train_time:241523ms step_avg:311.24ms
step:777/1770 train_time:241635ms step_avg:310.98ms
step:778/1770 train_time:241747ms step_avg:310.73ms
step:779/1770 train_time:241860ms step_avg:310.48ms
step:780/1770 train_time:241973ms step_avg:310.22ms
step:781/1770 train_time:242086ms step_avg:309.97ms
step:782/1770 train_time:242199ms step_avg:309.72ms
step:783/1770 train_time:242312ms step_avg:309.47ms
step:784/1770 train_time:242425ms step_avg:309.22ms
step:785/1770 train_time:242538ms step_avg:308.97ms
step:786/1770 train_time:242651ms step_avg:308.72ms
step:787/1770 train_time:242765ms step_avg:308.47ms
step:788/1770 train_time:242878ms step_avg:308.22ms
step:789/1770 train_time:242991ms step_avg:307.97ms
step:790/1770 train_time:243107ms step_avg:307.73ms
step:791/1770 train_time:243217ms step_avg:307.48ms
step:792/1770 train_time:243330ms step_avg:307.23ms
step:793/1770 train_time:243443ms step_avg:306.99ms
step:794/1770 train_time:243556ms step_avg:306.75ms
step:795/1770 train_time:243668ms step_avg:306.50ms
step:796/1770 train_time:243781ms step_avg:306.26ms
step:797/1770 train_time:243894ms step_avg:306.01ms
step:798/1770 train_time:244006ms step_avg:305.77ms
step:799/1770 train_time:244119ms step_avg:305.53ms
step:800/1770 train_time:244231ms step_avg:305.29ms
step:801/1770 train_time:244344ms step_avg:305.05ms
step:802/1770 train_time:244458ms step_avg:304.81ms
step:803/1770 train_time:244570ms step_avg:304.57ms
step:804/1770 train_time:244683ms step_avg:304.33ms
step:805/1770 train_time:244796ms step_avg:304.09ms
step:806/1770 train_time:244909ms step_avg:303.86ms
step:807/1770 train_time:245022ms step_avg:303.62ms
step:808/1770 train_time:245135ms step_avg:303.38ms
step:809/1770 train_time:245247ms step_avg:303.15ms
step:810/1770 train_time:245359ms step_avg:302.91ms
step:811/1770 train_time:245472ms step_avg:302.68ms
step:812/1770 train_time:245586ms step_avg:302.45ms
step:813/1770 train_time:245700ms step_avg:302.21ms
step:814/1770 train_time:245813ms step_avg:301.98ms
step:815/1770 train_time:245926ms step_avg:301.75ms
step:816/1770 train_time:246039ms step_avg:301.52ms
step:817/1770 train_time:246152ms step_avg:301.29ms
step:818/1770 train_time:246265ms step_avg:301.06ms
step:819/1770 train_time:246379ms step_avg:300.83ms
step:820/1770 train_time:246495ms step_avg:300.60ms
step:821/1770 train_time:246607ms step_avg:300.37ms
step:822/1770 train_time:246721ms step_avg:300.15ms
step:823/1770 train_time:246834ms step_avg:299.92ms
step:824/1770 train_time:246948ms step_avg:299.69ms
step:825/1770 train_time:247062ms step_avg:299.47ms
step:826/1770 train_time:247175ms step_avg:299.24ms
step:827/1770 train_time:247288ms step_avg:299.02ms
step:828/1770 train_time:247401ms step_avg:298.79ms
step:829/1770 train_time:247514ms step_avg:298.57ms
step:830/1770 train_time:247626ms step_avg:298.34ms
step:831/1770 train_time:247739ms step_avg:298.12ms
step:832/1770 train_time:247852ms step_avg:297.90ms
step:833/1770 train_time:247965ms step_avg:297.68ms
step:834/1770 train_time:248078ms step_avg:297.46ms
step:835/1770 train_time:248193ms step_avg:297.24ms
step:836/1770 train_time:248305ms step_avg:297.02ms
step:837/1770 train_time:248419ms step_avg:296.80ms
step:838/1770 train_time:248532ms step_avg:296.58ms
step:839/1770 train_time:248645ms step_avg:296.36ms
step:840/1770 train_time:248758ms step_avg:296.14ms
step:841/1770 train_time:248872ms step_avg:295.92ms
step:842/1770 train_time:248985ms step_avg:295.71ms
step:843/1770 train_time:249098ms step_avg:295.49ms
step:844/1770 train_time:249211ms step_avg:295.27ms
step:845/1770 train_time:249324ms step_avg:295.06ms
step:846/1770 train_time:249438ms step_avg:294.84ms
step:847/1770 train_time:249551ms step_avg:294.63ms
step:848/1770 train_time:249664ms step_avg:294.42ms
step:849/1770 train_time:249777ms step_avg:294.20ms
step:850/1770 train_time:249891ms step_avg:293.99ms
step:851/1770 train_time:250004ms step_avg:293.78ms
step:852/1770 train_time:250116ms step_avg:293.56ms
step:853/1770 train_time:250228ms step_avg:293.35ms
step:854/1770 train_time:250341ms step_avg:293.14ms
step:855/1770 train_time:250453ms step_avg:292.93ms
step:856/1770 train_time:250566ms step_avg:292.72ms
step:857/1770 train_time:250679ms step_avg:292.51ms
step:858/1770 train_time:250793ms step_avg:292.30ms
step:859/1770 train_time:250905ms step_avg:292.09ms
step:860/1770 train_time:251017ms step_avg:291.88ms
step:861/1770 train_time:251130ms step_avg:291.67ms
step:862/1770 train_time:251242ms step_avg:291.46ms
step:863/1770 train_time:251355ms step_avg:291.26ms
step:864/1770 train_time:251468ms step_avg:291.05ms
step:865/1770 train_time:251581ms step_avg:290.85ms
step:866/1770 train_time:251695ms step_avg:290.64ms
step:867/1770 train_time:251808ms step_avg:290.44ms
step:868/1770 train_time:251921ms step_avg:290.23ms
step:869/1770 train_time:252034ms step_avg:290.03ms
step:870/1770 train_time:252146ms step_avg:289.82ms
step:871/1770 train_time:252259ms step_avg:289.62ms
step:872/1770 train_time:252375ms step_avg:289.42ms
step:873/1770 train_time:252486ms step_avg:289.22ms
step:874/1770 train_time:252599ms step_avg:289.01ms
step:875/1770 train_time:252712ms step_avg:288.81ms
step:875/1770 val_loss:4.9790 train_time:252716ms step_avg:288.82ms
step:876/1770 train_time:252834ms step_avg:288.62ms
step:877/1770 train_time:252947ms step_avg:288.42ms
step:878/1770 train_time:253060ms step_avg:288.22ms
step:879/1770 train_time:253173ms step_avg:288.02ms
step:880/1770 train_time:253286ms step_avg:287.82ms
step:881/1770 train_time:253402ms step_avg:287.63ms
step:882/1770 train_time:253512ms step_avg:287.43ms
step:883/1770 train_time:253626ms step_avg:287.23ms
step:884/1770 train_time:253739ms step_avg:287.04ms
step:885/1770 train_time:253852ms step_avg:286.84ms
step:886/1770 train_time:253965ms step_avg:286.64ms
step:887/1770 train_time:254077ms step_avg:286.45ms
step:888/1770 train_time:254190ms step_avg:286.25ms
step:889/1770 train_time:254303ms step_avg:286.06ms
step:890/1770 train_time:254416ms step_avg:285.86ms
step:891/1770 train_time:254530ms step_avg:285.67ms
step:892/1770 train_time:254642ms step_avg:285.47ms
step:893/1770 train_time:254754ms step_avg:285.28ms
step:894/1770 train_time:254868ms step_avg:285.09ms
step:895/1770 train_time:254981ms step_avg:284.90ms
step:896/1770 train_time:255094ms step_avg:284.70ms
step:897/1770 train_time:255208ms step_avg:284.51ms
step:898/1770 train_time:255321ms step_avg:284.32ms
step:899/1770 train_time:255435ms step_avg:284.13ms
GALA update on step 900: GalaAdam LRs=[0.307883, 16.405020, 0.000161], GalaMuon LRs=[0.000383]
step:900/1770 train_time:255680ms step_avg:284.09ms
step:901/1770 train_time:255789ms step_avg:283.89ms
step:902/1770 train_time:255903ms step_avg:283.71ms
step:903/1770 train_time:256015ms step_avg:283.52ms
step:904/1770 train_time:256128ms step_avg:283.33ms
step:905/1770 train_time:256241ms step_avg:283.14ms
step:906/1770 train_time:256354ms step_avg:282.95ms
step:907/1770 train_time:256468ms step_avg:282.77ms
step:908/1770 train_time:256580ms step_avg:282.58ms
step:909/1770 train_time:256693ms step_avg:282.39ms
step:910/1770 train_time:256807ms step_avg:282.21ms
step:911/1770 train_time:256920ms step_avg:282.02ms
step:912/1770 train_time:257034ms step_avg:281.84ms
step:913/1770 train_time:257147ms step_avg:281.65ms
step:914/1770 train_time:257260ms step_avg:281.47ms
step:915/1770 train_time:257373ms step_avg:281.28ms
step:916/1770 train_time:257486ms step_avg:281.10ms
step:917/1770 train_time:257601ms step_avg:280.92ms
step:918/1770 train_time:257712ms step_avg:280.73ms
step:919/1770 train_time:257826ms step_avg:280.55ms
step:920/1770 train_time:257941ms step_avg:280.37ms
step:921/1770 train_time:258055ms step_avg:280.19ms
step:922/1770 train_time:258169ms step_avg:280.01ms
step:923/1770 train_time:258283ms step_avg:279.83ms
step:924/1770 train_time:258401ms step_avg:279.65ms
step:925/1770 train_time:258512ms step_avg:279.47ms
step:926/1770 train_time:258627ms step_avg:279.30ms
step:927/1770 train_time:258743ms step_avg:279.12ms
step:928/1770 train_time:258856ms step_avg:278.94ms
step:929/1770 train_time:258971ms step_avg:278.76ms
step:930/1770 train_time:259085ms step_avg:278.59ms
step:931/1770 train_time:259200ms step_avg:278.41ms
step:932/1770 train_time:259315ms step_avg:278.23ms
step:933/1770 train_time:259429ms step_avg:278.06ms
step:934/1770 train_time:259544ms step_avg:277.88ms
step:935/1770 train_time:259659ms step_avg:277.71ms
step:936/1770 train_time:259773ms step_avg:277.54ms
step:937/1770 train_time:259889ms step_avg:277.36ms
step:938/1770 train_time:260002ms step_avg:277.19ms
step:939/1770 train_time:260117ms step_avg:277.02ms
step:940/1770 train_time:260232ms step_avg:276.84ms
step:941/1770 train_time:260346ms step_avg:276.67ms
step:942/1770 train_time:260461ms step_avg:276.50ms
step:943/1770 train_time:260576ms step_avg:276.33ms
step:944/1770 train_time:260690ms step_avg:276.15ms
step:945/1770 train_time:260804ms step_avg:275.98ms
step:946/1770 train_time:260918ms step_avg:275.81ms
step:947/1770 train_time:261033ms step_avg:275.64ms
step:948/1770 train_time:261148ms step_avg:275.47ms
step:949/1770 train_time:261263ms step_avg:275.30ms
step:950/1770 train_time:261375ms step_avg:275.13ms
step:951/1770 train_time:261490ms step_avg:274.96ms
step:952/1770 train_time:261605ms step_avg:274.80ms
step:953/1770 train_time:261720ms step_avg:274.63ms
step:954/1770 train_time:261835ms step_avg:274.46ms
step:955/1770 train_time:261951ms step_avg:274.29ms
step:956/1770 train_time:262066ms step_avg:274.13ms
step:957/1770 train_time:262180ms step_avg:273.96ms
step:958/1770 train_time:262295ms step_avg:273.79ms
step:959/1770 train_time:262410ms step_avg:273.63ms
step:960/1770 train_time:262524ms step_avg:273.46ms
step:961/1770 train_time:262638ms step_avg:273.30ms
step:962/1770 train_time:262753ms step_avg:273.13ms
step:963/1770 train_time:262866ms step_avg:272.97ms
step:964/1770 train_time:262981ms step_avg:272.80ms
step:965/1770 train_time:263096ms step_avg:272.64ms
step:966/1770 train_time:263209ms step_avg:272.47ms
step:967/1770 train_time:263323ms step_avg:272.31ms
step:968/1770 train_time:263437ms step_avg:272.15ms
step:969/1770 train_time:263551ms step_avg:271.98ms
step:970/1770 train_time:263666ms step_avg:271.82ms
step:971/1770 train_time:263781ms step_avg:271.66ms
step:972/1770 train_time:263895ms step_avg:271.50ms
step:973/1770 train_time:264008ms step_avg:271.33ms
step:974/1770 train_time:264123ms step_avg:271.17ms
step:975/1770 train_time:264238ms step_avg:271.01ms
step:976/1770 train_time:264352ms step_avg:270.85ms
step:977/1770 train_time:264467ms step_avg:270.69ms
step:978/1770 train_time:264582ms step_avg:270.53ms
step:979/1770 train_time:264695ms step_avg:270.37ms
step:980/1770 train_time:264810ms step_avg:270.21ms
step:981/1770 train_time:264924ms step_avg:270.06ms
step:982/1770 train_time:265039ms step_avg:269.90ms
step:983/1770 train_time:265154ms step_avg:269.74ms
step:984/1770 train_time:265269ms step_avg:269.58ms
step:985/1770 train_time:265384ms step_avg:269.43ms
step:986/1770 train_time:265498ms step_avg:269.27ms
step:987/1770 train_time:265613ms step_avg:269.11ms
step:988/1770 train_time:265727ms step_avg:268.95ms
step:989/1770 train_time:265843ms step_avg:268.80ms
step:990/1770 train_time:265956ms step_avg:268.64ms
step:991/1770 train_time:266072ms step_avg:268.49ms
step:992/1770 train_time:266187ms step_avg:268.33ms
step:993/1770 train_time:266301ms step_avg:268.18ms
step:994/1770 train_time:266416ms step_avg:268.02ms
step:995/1770 train_time:266530ms step_avg:267.87ms
step:996/1770 train_time:266646ms step_avg:267.72ms
step:997/1770 train_time:266760ms step_avg:267.56ms
step:998/1770 train_time:266874ms step_avg:267.41ms
step:999/1770 train_time:266987ms step_avg:267.25ms
step:1000/1770 train_time:267103ms step_avg:267.10ms
step:1000/1770 val_loss:4.9018 train_time:267105ms step_avg:267.11ms
step:1001/1770 train_time:267220ms step_avg:266.95ms
step:1002/1770 train_time:267334ms step_avg:266.80ms
step:1003/1770 train_time:267449ms step_avg:266.65ms
step:1004/1770 train_time:267563ms step_avg:266.50ms
step:1005/1770 train_time:267678ms step_avg:266.35ms
step:1006/1770 train_time:267793ms step_avg:266.20ms
step:1007/1770 train_time:267908ms step_avg:266.05ms
step:1008/1770 train_time:268021ms step_avg:265.89ms
step:1009/1770 train_time:268136ms step_avg:265.74ms
step:1010/1770 train_time:268250ms step_avg:265.59ms
step:1011/1770 train_time:268365ms step_avg:265.45ms
step:1012/1770 train_time:268479ms step_avg:265.30ms
step:1013/1770 train_time:268593ms step_avg:265.15ms
step:1014/1770 train_time:268707ms step_avg:265.00ms
step:1015/1770 train_time:268821ms step_avg:264.85ms
step:1016/1770 train_time:268936ms step_avg:264.70ms
step:1017/1770 train_time:269051ms step_avg:264.55ms
step:1018/1770 train_time:269165ms step_avg:264.41ms
step:1019/1770 train_time:269282ms step_avg:264.26ms
step:1020/1770 train_time:269394ms step_avg:264.11ms
step:1021/1770 train_time:269508ms step_avg:263.96ms
step:1022/1770 train_time:269623ms step_avg:263.82ms
step:1023/1770 train_time:269737ms step_avg:263.67ms
step:1024/1770 train_time:269852ms step_avg:263.53ms
step:1025/1770 train_time:269968ms step_avg:263.38ms
step:1026/1770 train_time:270082ms step_avg:263.24ms
step:1027/1770 train_time:270196ms step_avg:263.09ms
step:1028/1770 train_time:270310ms step_avg:262.95ms
step:1029/1770 train_time:270425ms step_avg:262.80ms
step:1030/1770 train_time:270540ms step_avg:262.66ms
step:1031/1770 train_time:270654ms step_avg:262.52ms
step:1032/1770 train_time:270770ms step_avg:262.37ms
step:1033/1770 train_time:270885ms step_avg:262.23ms
step:1034/1770 train_time:271000ms step_avg:262.09ms
step:1035/1770 train_time:271114ms step_avg:261.95ms
step:1036/1770 train_time:271229ms step_avg:261.80ms
step:1037/1770 train_time:271344ms step_avg:261.66ms
step:1038/1770 train_time:271457ms step_avg:261.52ms
step:1039/1770 train_time:271572ms step_avg:261.38ms
step:1040/1770 train_time:271687ms step_avg:261.24ms
step:1041/1770 train_time:271801ms step_avg:261.10ms
step:1042/1770 train_time:271916ms step_avg:260.96ms
step:1043/1770 train_time:272031ms step_avg:260.82ms
step:1044/1770 train_time:272146ms step_avg:260.68ms
step:1045/1770 train_time:272260ms step_avg:260.54ms
step:1046/1770 train_time:272375ms step_avg:260.40ms
step:1047/1770 train_time:272489ms step_avg:260.26ms
step:1048/1770 train_time:272603ms step_avg:260.12ms
step:1049/1770 train_time:272718ms step_avg:259.98ms
GALA update on step 1050: GalaAdam LRs=[0.285667, 16.007898, 0.000149], GalaMuon LRs=[0.000327]
step:1050/1770 train_time:272970ms step_avg:259.97ms
step:1051/1770 train_time:273081ms step_avg:259.83ms
step:1052/1770 train_time:273196ms step_avg:259.69ms
step:1053/1770 train_time:273310ms step_avg:259.55ms
step:1054/1770 train_time:273424ms step_avg:259.42ms
step:1055/1770 train_time:273537ms step_avg:259.28ms
step:1056/1770 train_time:273653ms step_avg:259.14ms
step:1057/1770 train_time:273766ms step_avg:259.00ms
step:1058/1770 train_time:273881ms step_avg:258.87ms
step:1059/1770 train_time:273995ms step_avg:258.73ms
step:1060/1770 train_time:274111ms step_avg:258.60ms
step:1061/1770 train_time:274224ms step_avg:258.46ms
step:1062/1770 train_time:274340ms step_avg:258.32ms
step:1063/1770 train_time:274457ms step_avg:258.19ms
step:1064/1770 train_time:274569ms step_avg:258.05ms
step:1065/1770 train_time:274687ms step_avg:257.92ms
step:1066/1770 train_time:274800ms step_avg:257.79ms
step:1067/1770 train_time:274917ms step_avg:257.65ms
step:1068/1770 train_time:275031ms step_avg:257.52ms
step:1069/1770 train_time:275148ms step_avg:257.39ms
step:1070/1770 train_time:275262ms step_avg:257.25ms
step:1071/1770 train_time:275374ms step_avg:257.12ms
step:1072/1770 train_time:275490ms step_avg:256.99ms
step:1073/1770 train_time:275605ms step_avg:256.85ms
step:1074/1770 train_time:275719ms step_avg:256.72ms
step:1075/1770 train_time:275834ms step_avg:256.59ms
step:1076/1770 train_time:275949ms step_avg:256.46ms
step:1077/1770 train_time:276067ms step_avg:256.33ms
step:1078/1770 train_time:276183ms step_avg:256.20ms
step:1079/1770 train_time:276296ms step_avg:256.07ms
step:1080/1770 train_time:276410ms step_avg:255.94ms
step:1081/1770 train_time:276525ms step_avg:255.80ms
step:1082/1770 train_time:276639ms step_avg:255.67ms
step:1083/1770 train_time:276754ms step_avg:255.54ms
step:1084/1770 train_time:276869ms step_avg:255.41ms
step:1085/1770 train_time:276984ms step_avg:255.28ms
step:1086/1770 train_time:277098ms step_avg:255.15ms
step:1087/1770 train_time:277212ms step_avg:255.03ms
step:1088/1770 train_time:277326ms step_avg:254.90ms
step:1089/1770 train_time:277448ms step_avg:254.77ms
step:1090/1770 train_time:277562ms step_avg:254.64ms
step:1091/1770 train_time:277678ms step_avg:254.52ms
step:1092/1770 train_time:277793ms step_avg:254.39ms
step:1093/1770 train_time:277908ms step_avg:254.26ms
step:1094/1770 train_time:278022ms step_avg:254.13ms
step:1095/1770 train_time:278138ms step_avg:254.01ms
step:1096/1770 train_time:278252ms step_avg:253.88ms
step:1097/1770 train_time:278367ms step_avg:253.75ms
step:1098/1770 train_time:278481ms step_avg:253.63ms
step:1099/1770 train_time:278597ms step_avg:253.50ms
step:1100/1770 train_time:278711ms step_avg:253.37ms
step:1101/1770 train_time:278825ms step_avg:253.25ms
step:1102/1770 train_time:278940ms step_avg:253.12ms
step:1103/1770 train_time:279057ms step_avg:253.00ms
step:1104/1770 train_time:279171ms step_avg:252.87ms
step:1105/1770 train_time:279286ms step_avg:252.75ms
step:1106/1770 train_time:279401ms step_avg:252.62ms
step:1107/1770 train_time:279516ms step_avg:252.50ms
step:1108/1770 train_time:279631ms step_avg:252.37ms
step:1109/1770 train_time:279747ms step_avg:252.25ms
step:1110/1770 train_time:279860ms step_avg:252.13ms
step:1111/1770 train_time:279976ms step_avg:252.00ms
step:1112/1770 train_time:280090ms step_avg:251.88ms
step:1113/1770 train_time:280205ms step_avg:251.76ms
step:1114/1770 train_time:280319ms step_avg:251.63ms
step:1115/1770 train_time:280434ms step_avg:251.51ms
step:1116/1770 train_time:280549ms step_avg:251.39ms
step:1117/1770 train_time:280664ms step_avg:251.27ms
step:1118/1770 train_time:280779ms step_avg:251.14ms
step:1119/1770 train_time:280893ms step_avg:251.02ms
step:1120/1770 train_time:281008ms step_avg:250.90ms
step:1121/1770 train_time:281123ms step_avg:250.78ms
step:1122/1770 train_time:281238ms step_avg:250.66ms
step:1123/1770 train_time:281353ms step_avg:250.54ms
step:1124/1770 train_time:281468ms step_avg:250.42ms
step:1125/1770 train_time:281583ms step_avg:250.30ms
step:1125/1770 val_loss:4.8540 train_time:281586ms step_avg:250.30ms
step:1126/1770 train_time:281702ms step_avg:250.18ms
step:1127/1770 train_time:281817ms step_avg:250.06ms
step:1128/1770 train_time:281931ms step_avg:249.94ms
step:1129/1770 train_time:282047ms step_avg:249.82ms
step:1130/1770 train_time:282161ms step_avg:249.70ms
step:1131/1770 train_time:282275ms step_avg:249.58ms
step:1132/1770 train_time:282389ms step_avg:249.46ms
step:1133/1770 train_time:282505ms step_avg:249.34ms
step:1134/1770 train_time:282619ms step_avg:249.22ms
step:1135/1770 train_time:282735ms step_avg:249.11ms
step:1136/1770 train_time:282849ms step_avg:248.99ms
step:1137/1770 train_time:282964ms step_avg:248.87ms
step:1138/1770 train_time:283079ms step_avg:248.75ms
step:1139/1770 train_time:283194ms step_avg:248.63ms
step:1140/1770 train_time:283308ms step_avg:248.52ms
step:1141/1770 train_time:283423ms step_avg:248.40ms
step:1142/1770 train_time:283538ms step_avg:248.28ms
step:1143/1770 train_time:283652ms step_avg:248.16ms
step:1144/1770 train_time:283766ms step_avg:248.05ms
step:1145/1770 train_time:283882ms step_avg:247.93ms
step:1146/1770 train_time:283997ms step_avg:247.82ms
step:1147/1770 train_time:284111ms step_avg:247.70ms
step:1148/1770 train_time:284226ms step_avg:247.58ms
step:1149/1770 train_time:284340ms step_avg:247.47ms
step:1150/1770 train_time:284455ms step_avg:247.35ms
step:1151/1770 train_time:284568ms step_avg:247.24ms
step:1152/1770 train_time:284682ms step_avg:247.12ms
step:1153/1770 train_time:284797ms step_avg:247.01ms
step:1154/1770 train_time:284912ms step_avg:246.89ms
step:1155/1770 train_time:285028ms step_avg:246.78ms
step:1156/1770 train_time:285144ms step_avg:246.66ms
step:1157/1770 train_time:285258ms step_avg:246.55ms
step:1158/1770 train_time:285376ms step_avg:246.44ms
step:1159/1770 train_time:285488ms step_avg:246.32ms
step:1160/1770 train_time:285603ms step_avg:246.21ms
step:1161/1770 train_time:285718ms step_avg:246.10ms
step:1162/1770 train_time:285832ms step_avg:245.98ms
step:1163/1770 train_time:285947ms step_avg:245.87ms
step:1164/1770 train_time:286061ms step_avg:245.76ms
step:1165/1770 train_time:286175ms step_avg:245.64ms
step:1166/1770 train_time:286291ms step_avg:245.53ms
step:1167/1770 train_time:286405ms step_avg:245.42ms
step:1168/1770 train_time:286520ms step_avg:245.31ms
step:1169/1770 train_time:286634ms step_avg:245.20ms
step:1170/1770 train_time:286748ms step_avg:245.08ms
step:1171/1770 train_time:286862ms step_avg:244.97ms
step:1172/1770 train_time:286977ms step_avg:244.86ms
step:1173/1770 train_time:287093ms step_avg:244.75ms
step:1174/1770 train_time:287208ms step_avg:244.64ms
step:1175/1770 train_time:287323ms step_avg:244.53ms
step:1176/1770 train_time:287437ms step_avg:244.42ms
step:1177/1770 train_time:287552ms step_avg:244.31ms
step:1178/1770 train_time:287666ms step_avg:244.20ms
step:1179/1770 train_time:287781ms step_avg:244.09ms
step:1180/1770 train_time:287897ms step_avg:243.98ms
step:1181/1770 train_time:288013ms step_avg:243.87ms
step:1182/1770 train_time:288127ms step_avg:243.76ms
step:1183/1770 train_time:288246ms step_avg:243.66ms
step:1184/1770 train_time:288360ms step_avg:243.55ms
step:1185/1770 train_time:288476ms step_avg:243.44ms
step:1186/1770 train_time:288593ms step_avg:243.33ms
step:1187/1770 train_time:288709ms step_avg:243.23ms
step:1188/1770 train_time:288825ms step_avg:243.12ms
step:1189/1770 train_time:288940ms step_avg:243.01ms
step:1190/1770 train_time:289057ms step_avg:242.91ms
step:1191/1770 train_time:289172ms step_avg:242.80ms
step:1192/1770 train_time:289288ms step_avg:242.69ms
step:1193/1770 train_time:289404ms step_avg:242.59ms
step:1194/1770 train_time:289520ms step_avg:242.48ms
step:1195/1770 train_time:289635ms step_avg:242.37ms
step:1196/1770 train_time:289750ms step_avg:242.27ms
step:1197/1770 train_time:289865ms step_avg:242.16ms
step:1198/1770 train_time:289982ms step_avg:242.05ms
step:1199/1770 train_time:290097ms step_avg:241.95ms
GALA update on step 1200: GalaAdam LRs=[0.253194, 15.645800, 0.000155], GalaMuon LRs=[0.000279]
step:1200/1770 train_time:290348ms step_avg:241.96ms
step:1201/1770 train_time:290461ms step_avg:241.85ms
step:1202/1770 train_time:290577ms step_avg:241.74ms
step:1203/1770 train_time:290693ms step_avg:241.64ms
step:1204/1770 train_time:290809ms step_avg:241.54ms
step:1205/1770 train_time:290925ms step_avg:241.43ms
step:1206/1770 train_time:291041ms step_avg:241.33ms
step:1207/1770 train_time:291156ms step_avg:241.22ms
step:1208/1770 train_time:291271ms step_avg:241.12ms
step:1209/1770 train_time:291388ms step_avg:241.02ms
step:1210/1770 train_time:291505ms step_avg:240.91ms
step:1211/1770 train_time:291621ms step_avg:240.81ms
step:1212/1770 train_time:291737ms step_avg:240.71ms
step:1213/1770 train_time:291853ms step_avg:240.60ms
step:1214/1770 train_time:291968ms step_avg:240.50ms
step:1215/1770 train_time:292087ms step_avg:240.40ms
step:1216/1770 train_time:292200ms step_avg:240.30ms
step:1217/1770 train_time:292315ms step_avg:240.19ms
step:1218/1770 train_time:292430ms step_avg:240.09ms
step:1219/1770 train_time:292546ms step_avg:239.99ms
step:1220/1770 train_time:292662ms step_avg:239.89ms
step:1221/1770 train_time:292778ms step_avg:239.79ms
step:1222/1770 train_time:292894ms step_avg:239.68ms
step:1223/1770 train_time:293011ms step_avg:239.58ms
step:1224/1770 train_time:293126ms step_avg:239.48ms
step:1225/1770 train_time:293242ms step_avg:239.38ms
step:1226/1770 train_time:293359ms step_avg:239.28ms
step:1227/1770 train_time:293475ms step_avg:239.18ms
step:1228/1770 train_time:293591ms step_avg:239.08ms
step:1229/1770 train_time:293707ms step_avg:238.98ms
step:1230/1770 train_time:293823ms step_avg:238.88ms
step:1231/1770 train_time:293938ms step_avg:238.78ms
step:1232/1770 train_time:294055ms step_avg:238.68ms
step:1233/1770 train_time:294171ms step_avg:238.58ms
step:1234/1770 train_time:294287ms step_avg:238.48ms
step:1235/1770 train_time:294403ms step_avg:238.38ms
step:1236/1770 train_time:294518ms step_avg:238.28ms
step:1237/1770 train_time:294633ms step_avg:238.18ms
step:1238/1770 train_time:294748ms step_avg:238.08ms
step:1239/1770 train_time:294864ms step_avg:237.99ms
step:1240/1770 train_time:294980ms step_avg:237.89ms
step:1241/1770 train_time:295096ms step_avg:237.79ms
step:1242/1770 train_time:295211ms step_avg:237.69ms
step:1243/1770 train_time:295327ms step_avg:237.59ms
step:1244/1770 train_time:295442ms step_avg:237.49ms
step:1245/1770 train_time:295558ms step_avg:237.40ms
step:1246/1770 train_time:295675ms step_avg:237.30ms
step:1247/1770 train_time:295791ms step_avg:237.20ms
step:1248/1770 train_time:295907ms step_avg:237.10ms
step:1249/1770 train_time:296024ms step_avg:237.01ms
step:1250/1770 train_time:296140ms step_avg:236.91ms
step:1250/1770 val_loss:4.7985 train_time:296144ms step_avg:236.92ms
step:1251/1770 train_time:296260ms step_avg:236.82ms
step:1252/1770 train_time:296376ms step_avg:236.72ms
step:1253/1770 train_time:296492ms step_avg:236.63ms
step:1254/1770 train_time:296608ms step_avg:236.53ms
step:1255/1770 train_time:296723ms step_avg:236.43ms
step:1256/1770 train_time:296839ms step_avg:236.34ms
step:1257/1770 train_time:296955ms step_avg:236.24ms
step:1258/1770 train_time:297070ms step_avg:236.14ms
step:1259/1770 train_time:297186ms step_avg:236.05ms
step:1260/1770 train_time:297302ms step_avg:235.95ms
step:1261/1770 train_time:297419ms step_avg:235.86ms
step:1262/1770 train_time:297535ms step_avg:235.76ms
step:1263/1770 train_time:297651ms step_avg:235.67ms
step:1264/1770 train_time:297767ms step_avg:235.57ms
step:1265/1770 train_time:297883ms step_avg:235.48ms
step:1266/1770 train_time:297999ms step_avg:235.39ms
step:1267/1770 train_time:298115ms step_avg:235.29ms
step:1268/1770 train_time:298231ms step_avg:235.20ms
step:1269/1770 train_time:298347ms step_avg:235.10ms
step:1270/1770 train_time:298462ms step_avg:235.01ms
step:1271/1770 train_time:298579ms step_avg:234.92ms
step:1272/1770 train_time:298696ms step_avg:234.82ms
step:1273/1770 train_time:298811ms step_avg:234.73ms
step:1274/1770 train_time:298926ms step_avg:234.64ms
step:1275/1770 train_time:299043ms step_avg:234.54ms
step:1276/1770 train_time:299159ms step_avg:234.45ms
step:1277/1770 train_time:299275ms step_avg:234.36ms
step:1278/1770 train_time:299391ms step_avg:234.27ms
step:1279/1770 train_time:299507ms step_avg:234.17ms
step:1280/1770 train_time:299623ms step_avg:234.08ms
step:1281/1770 train_time:299739ms step_avg:233.99ms
step:1282/1770 train_time:299854ms step_avg:233.90ms
step:1283/1770 train_time:299970ms step_avg:233.80ms
step:1284/1770 train_time:300086ms step_avg:233.71ms
step:1285/1770 train_time:300202ms step_avg:233.62ms
step:1286/1770 train_time:300319ms step_avg:233.53ms
step:1287/1770 train_time:300434ms step_avg:233.44ms
step:1288/1770 train_time:300549ms step_avg:233.35ms
step:1289/1770 train_time:300665ms step_avg:233.25ms
step:1290/1770 train_time:300781ms step_avg:233.16ms
step:1291/1770 train_time:300897ms step_avg:233.07ms
step:1292/1770 train_time:301016ms step_avg:232.98ms
step:1293/1770 train_time:301128ms step_avg:232.89ms
step:1294/1770 train_time:301244ms step_avg:232.80ms
step:1295/1770 train_time:301360ms step_avg:232.71ms
step:1296/1770 train_time:301478ms step_avg:232.62ms
step:1297/1770 train_time:301591ms step_avg:232.53ms
step:1298/1770 train_time:301707ms step_avg:232.44ms
step:1299/1770 train_time:301824ms step_avg:232.35ms
step:1300/1770 train_time:301940ms step_avg:232.26ms
step:1301/1770 train_time:302057ms step_avg:232.17ms
step:1302/1770 train_time:302172ms step_avg:232.08ms
step:1303/1770 train_time:302288ms step_avg:231.99ms
step:1304/1770 train_time:302405ms step_avg:231.91ms
step:1305/1770 train_time:302521ms step_avg:231.82ms
step:1306/1770 train_time:302637ms step_avg:231.73ms
step:1307/1770 train_time:302759ms step_avg:231.64ms
step:1308/1770 train_time:302870ms step_avg:231.55ms
step:1309/1770 train_time:302987ms step_avg:231.46ms
step:1310/1770 train_time:303105ms step_avg:231.38ms
step:1311/1770 train_time:303220ms step_avg:231.29ms
step:1312/1770 train_time:303335ms step_avg:231.20ms
step:1313/1770 train_time:303450ms step_avg:231.11ms
step:1314/1770 train_time:303567ms step_avg:231.03ms
step:1315/1770 train_time:303685ms step_avg:230.94ms
step:1316/1770 train_time:303800ms step_avg:230.85ms
step:1317/1770 train_time:303916ms step_avg:230.76ms
step:1318/1770 train_time:304033ms step_avg:230.68ms
step:1319/1770 train_time:304149ms step_avg:230.59ms
step:1320/1770 train_time:304265ms step_avg:230.50ms
step:1321/1770 train_time:304381ms step_avg:230.42ms
step:1322/1770 train_time:304497ms step_avg:230.33ms
step:1323/1770 train_time:304613ms step_avg:230.24ms
step:1324/1770 train_time:304728ms step_avg:230.16ms
step:1325/1770 train_time:304844ms step_avg:230.07ms
step:1326/1770 train_time:304961ms step_avg:229.99ms
step:1327/1770 train_time:305077ms step_avg:229.90ms
step:1328/1770 train_time:305194ms step_avg:229.81ms
step:1329/1770 train_time:305310ms step_avg:229.73ms
step:1330/1770 train_time:305426ms step_avg:229.64ms
step:1331/1770 train_time:305542ms step_avg:229.56ms
step:1332/1770 train_time:305659ms step_avg:229.47ms
step:1333/1770 train_time:305774ms step_avg:229.39ms
step:1334/1770 train_time:305890ms step_avg:229.30ms
step:1335/1770 train_time:306006ms step_avg:229.22ms
step:1336/1770 train_time:306122ms step_avg:229.13ms
step:1337/1770 train_time:306238ms step_avg:229.05ms
step:1338/1770 train_time:306355ms step_avg:228.96ms
step:1339/1770 train_time:306470ms step_avg:228.88ms
step:1340/1770 train_time:306587ms step_avg:228.80ms
step:1341/1770 train_time:306704ms step_avg:228.71ms
step:1342/1770 train_time:306820ms step_avg:228.63ms
step:1343/1770 train_time:306936ms step_avg:228.54ms
step:1344/1770 train_time:307051ms step_avg:228.46ms
step:1345/1770 train_time:307168ms step_avg:228.38ms
step:1346/1770 train_time:307284ms step_avg:228.29ms
step:1347/1770 train_time:307400ms step_avg:228.21ms
step:1348/1770 train_time:307521ms step_avg:228.13ms
step:1349/1770 train_time:307634ms step_avg:228.05ms
GALA update on step 1350: GalaAdam LRs=[0.245913, 15.277559, 0.000141], GalaMuon LRs=[0.000251]
step:1350/1770 train_time:307885ms step_avg:228.06ms
step:1351/1770 train_time:307998ms step_avg:227.98ms
step:1352/1770 train_time:308114ms step_avg:227.89ms
step:1353/1770 train_time:308230ms step_avg:227.81ms
step:1354/1770 train_time:308347ms step_avg:227.73ms
step:1355/1770 train_time:308463ms step_avg:227.65ms
step:1356/1770 train_time:308578ms step_avg:227.57ms
step:1357/1770 train_time:308695ms step_avg:227.48ms
step:1358/1770 train_time:308811ms step_avg:227.40ms
step:1359/1770 train_time:308930ms step_avg:227.32ms
step:1360/1770 train_time:309043ms step_avg:227.24ms
step:1361/1770 train_time:309158ms step_avg:227.16ms
step:1362/1770 train_time:309274ms step_avg:227.07ms
step:1363/1770 train_time:309391ms step_avg:226.99ms
step:1364/1770 train_time:309506ms step_avg:226.91ms
step:1365/1770 train_time:309623ms step_avg:226.83ms
step:1366/1770 train_time:309739ms step_avg:226.75ms
step:1367/1770 train_time:309855ms step_avg:226.67ms
step:1368/1770 train_time:309972ms step_avg:226.59ms
step:1369/1770 train_time:310088ms step_avg:226.51ms
step:1370/1770 train_time:310203ms step_avg:226.43ms
step:1371/1770 train_time:310320ms step_avg:226.35ms
step:1372/1770 train_time:310437ms step_avg:226.27ms
step:1373/1770 train_time:310551ms step_avg:226.18ms
step:1374/1770 train_time:310667ms step_avg:226.10ms
step:1375/1770 train_time:310784ms step_avg:226.02ms
step:1375/1770 val_loss:4.7594 train_time:310788ms step_avg:226.03ms
step:1376/1770 train_time:310904ms step_avg:225.95ms
step:1377/1770 train_time:311021ms step_avg:225.87ms
step:1378/1770 train_time:311138ms step_avg:225.79ms
step:1379/1770 train_time:311255ms step_avg:225.71ms
step:1380/1770 train_time:311371ms step_avg:225.63ms
step:1381/1770 train_time:311489ms step_avg:225.55ms
step:1382/1770 train_time:311608ms step_avg:225.48ms
step:1383/1770 train_time:311724ms step_avg:225.40ms
step:1384/1770 train_time:311841ms step_avg:225.32ms
step:1385/1770 train_time:311957ms step_avg:225.24ms
step:1386/1770 train_time:312074ms step_avg:225.16ms
step:1387/1770 train_time:312190ms step_avg:225.08ms
step:1388/1770 train_time:312306ms step_avg:225.00ms
step:1389/1770 train_time:312423ms step_avg:224.93ms
step:1390/1770 train_time:312540ms step_avg:224.85ms
step:1391/1770 train_time:312656ms step_avg:224.77ms
step:1392/1770 train_time:312773ms step_avg:224.69ms
step:1393/1770 train_time:312888ms step_avg:224.61ms
step:1394/1770 train_time:313003ms step_avg:224.54ms
step:1395/1770 train_time:313120ms step_avg:224.46ms
step:1396/1770 train_time:313236ms step_avg:224.38ms
step:1397/1770 train_time:313352ms step_avg:224.30ms
step:1398/1770 train_time:313467ms step_avg:224.23ms
step:1399/1770 train_time:313585ms step_avg:224.15ms
step:1400/1770 train_time:313701ms step_avg:224.07ms
step:1401/1770 train_time:313818ms step_avg:224.00ms
step:1402/1770 train_time:313935ms step_avg:223.92ms
step:1403/1770 train_time:314051ms step_avg:223.84ms
step:1404/1770 train_time:314167ms step_avg:223.77ms
step:1405/1770 train_time:314284ms step_avg:223.69ms
step:1406/1770 train_time:314399ms step_avg:223.61ms
step:1407/1770 train_time:314516ms step_avg:223.54ms
step:1408/1770 train_time:314632ms step_avg:223.46ms
step:1409/1770 train_time:314748ms step_avg:223.38ms
step:1410/1770 train_time:314863ms step_avg:223.31ms
step:1411/1770 train_time:314979ms step_avg:223.23ms
step:1412/1770 train_time:315094ms step_avg:223.15ms
step:1413/1770 train_time:315209ms step_avg:223.08ms
step:1414/1770 train_time:315325ms step_avg:223.00ms
step:1415/1770 train_time:315441ms step_avg:222.93ms
step:1416/1770 train_time:315558ms step_avg:222.85ms
step:1417/1770 train_time:315673ms step_avg:222.78ms
step:1418/1770 train_time:315788ms step_avg:222.70ms
step:1419/1770 train_time:315904ms step_avg:222.62ms
step:1420/1770 train_time:316021ms step_avg:222.55ms
step:1421/1770 train_time:316137ms step_avg:222.48ms
step:1422/1770 train_time:316254ms step_avg:222.40ms
step:1423/1770 train_time:316372ms step_avg:222.33ms
step:1424/1770 train_time:316487ms step_avg:222.25ms
step:1425/1770 train_time:316604ms step_avg:222.18ms
step:1426/1770 train_time:316720ms step_avg:222.10ms
step:1427/1770 train_time:316835ms step_avg:222.03ms
step:1428/1770 train_time:316953ms step_avg:221.96ms
step:1429/1770 train_time:317069ms step_avg:221.88ms
step:1430/1770 train_time:317185ms step_avg:221.81ms
step:1431/1770 train_time:317302ms step_avg:221.73ms
step:1432/1770 train_time:317420ms step_avg:221.66ms
step:1433/1770 train_time:317536ms step_avg:221.59ms
step:1434/1770 train_time:317651ms step_avg:221.51ms
step:1435/1770 train_time:317768ms step_avg:221.44ms
step:1436/1770 train_time:317884ms step_avg:221.37ms
step:1437/1770 train_time:318001ms step_avg:221.29ms
step:1438/1770 train_time:318116ms step_avg:221.22ms
step:1439/1770 train_time:318233ms step_avg:221.15ms
step:1440/1770 train_time:318350ms step_avg:221.08ms
step:1441/1770 train_time:318467ms step_avg:221.00ms
step:1442/1770 train_time:318583ms step_avg:220.93ms
step:1443/1770 train_time:318699ms step_avg:220.86ms
step:1444/1770 train_time:318816ms step_avg:220.79ms
step:1445/1770 train_time:318934ms step_avg:220.72ms
step:1446/1770 train_time:319051ms step_avg:220.64ms
step:1447/1770 train_time:319169ms step_avg:220.57ms
step:1448/1770 train_time:319287ms step_avg:220.50ms
step:1449/1770 train_time:319405ms step_avg:220.43ms
step:1450/1770 train_time:319522ms step_avg:220.36ms
step:1451/1770 train_time:319639ms step_avg:220.29ms
step:1452/1770 train_time:319755ms step_avg:220.22ms
step:1453/1770 train_time:319873ms step_avg:220.15ms
step:1454/1770 train_time:319990ms step_avg:220.08ms
step:1455/1770 train_time:320107ms step_avg:220.00ms
step:1456/1770 train_time:320224ms step_avg:219.93ms
step:1457/1770 train_time:320342ms step_avg:219.86ms
step:1458/1770 train_time:320460ms step_avg:219.79ms
step:1459/1770 train_time:320578ms step_avg:219.72ms
step:1460/1770 train_time:320695ms step_avg:219.65ms
step:1461/1770 train_time:320811ms step_avg:219.58ms
step:1462/1770 train_time:320929ms step_avg:219.51ms
step:1463/1770 train_time:321046ms step_avg:219.44ms
step:1464/1770 train_time:321163ms step_avg:219.37ms
step:1465/1770 train_time:321279ms step_avg:219.30ms
step:1466/1770 train_time:321396ms step_avg:219.23ms
step:1467/1770 train_time:321512ms step_avg:219.16ms
step:1468/1770 train_time:321629ms step_avg:219.09ms
step:1469/1770 train_time:321746ms step_avg:219.02ms
step:1470/1770 train_time:321863ms step_avg:218.95ms
step:1471/1770 train_time:321978ms step_avg:218.88ms
step:1472/1770 train_time:322095ms step_avg:218.81ms
step:1473/1770 train_time:322213ms step_avg:218.75ms
step:1474/1770 train_time:322329ms step_avg:218.68ms
step:1475/1770 train_time:322446ms step_avg:218.61ms
step:1476/1770 train_time:322564ms step_avg:218.54ms
step:1477/1770 train_time:322682ms step_avg:218.47ms
step:1478/1770 train_time:322799ms step_avg:218.40ms
step:1479/1770 train_time:322917ms step_avg:218.33ms
step:1480/1770 train_time:323035ms step_avg:218.27ms
step:1481/1770 train_time:323151ms step_avg:218.20ms
step:1482/1770 train_time:323269ms step_avg:218.13ms
step:1483/1770 train_time:323387ms step_avg:218.06ms
step:1484/1770 train_time:323504ms step_avg:217.99ms
step:1485/1770 train_time:323621ms step_avg:217.93ms
step:1486/1770 train_time:323738ms step_avg:217.86ms
step:1487/1770 train_time:323855ms step_avg:217.79ms
step:1488/1770 train_time:323975ms step_avg:217.72ms
step:1489/1770 train_time:324089ms step_avg:217.66ms
step:1490/1770 train_time:324206ms step_avg:217.59ms
step:1491/1770 train_time:324322ms step_avg:217.52ms
step:1492/1770 train_time:324439ms step_avg:217.45ms
step:1493/1770 train_time:324556ms step_avg:217.39ms
step:1494/1770 train_time:324673ms step_avg:217.32ms
step:1495/1770 train_time:324789ms step_avg:217.25ms
step:1496/1770 train_time:324905ms step_avg:217.18ms
step:1497/1770 train_time:325022ms step_avg:217.12ms
step:1498/1770 train_time:325140ms step_avg:217.05ms
step:1499/1770 train_time:325260ms step_avg:216.98ms
GALA update on step 1500: GalaAdam LRs=[0.244734, 14.943495, 0.000143], GalaMuon LRs=[0.000247]
step:1500/1770 train_time:325514ms step_avg:217.01ms
step:1500/1770 val_loss:4.7303 train_time:325514ms step_avg:217.01ms
step:1501/1770 train_time:325631ms step_avg:216.94ms
step:1502/1770 train_time:325749ms step_avg:216.88ms
step:1503/1770 train_time:325866ms step_avg:216.81ms
step:1504/1770 train_time:325984ms step_avg:216.74ms
step:1505/1770 train_time:326101ms step_avg:216.68ms
step:1506/1770 train_time:326218ms step_avg:216.61ms
step:1507/1770 train_time:326337ms step_avg:216.55ms
step:1508/1770 train_time:326454ms step_avg:216.48ms
step:1509/1770 train_time:326572ms step_avg:216.42ms
step:1510/1770 train_time:326688ms step_avg:216.35ms
step:1511/1770 train_time:326806ms step_avg:216.28ms
step:1512/1770 train_time:326925ms step_avg:216.22ms
step:1513/1770 train_time:327042ms step_avg:216.15ms
step:1514/1770 train_time:327158ms step_avg:216.09ms
step:1515/1770 train_time:327275ms step_avg:216.02ms
step:1516/1770 train_time:327393ms step_avg:215.96ms
step:1517/1770 train_time:327511ms step_avg:215.89ms
step:1518/1770 train_time:327629ms step_avg:215.83ms
step:1519/1770 train_time:327747ms step_avg:215.76ms
step:1520/1770 train_time:327865ms step_avg:215.70ms
step:1521/1770 train_time:327983ms step_avg:215.64ms
step:1522/1770 train_time:328099ms step_avg:215.57ms
step:1523/1770 train_time:328217ms step_avg:215.51ms
step:1524/1770 train_time:328333ms step_avg:215.44ms
step:1525/1770 train_time:328451ms step_avg:215.38ms
step:1526/1770 train_time:328569ms step_avg:215.31ms
step:1527/1770 train_time:328687ms step_avg:215.25ms
step:1528/1770 train_time:328804ms step_avg:215.19ms
step:1529/1770 train_time:328921ms step_avg:215.12ms
step:1530/1770 train_time:329039ms step_avg:215.06ms
step:1531/1770 train_time:329155ms step_avg:214.99ms
step:1532/1770 train_time:329273ms step_avg:214.93ms
step:1533/1770 train_time:329390ms step_avg:214.87ms
step:1534/1770 train_time:329507ms step_avg:214.80ms
step:1535/1770 train_time:329624ms step_avg:214.74ms
step:1536/1770 train_time:329741ms step_avg:214.68ms
step:1537/1770 train_time:329858ms step_avg:214.61ms
step:1538/1770 train_time:329975ms step_avg:214.55ms
step:1539/1770 train_time:330092ms step_avg:214.48ms
step:1540/1770 train_time:330209ms step_avg:214.42ms
step:1541/1770 train_time:330327ms step_avg:214.36ms
step:1542/1770 train_time:330443ms step_avg:214.30ms
step:1543/1770 train_time:330560ms step_avg:214.23ms
step:1544/1770 train_time:330676ms step_avg:214.17ms
step:1545/1770 train_time:330793ms step_avg:214.11ms
step:1546/1770 train_time:330912ms step_avg:214.04ms
step:1547/1770 train_time:331029ms step_avg:213.98ms
step:1548/1770 train_time:331145ms step_avg:213.92ms
step:1549/1770 train_time:331262ms step_avg:213.86ms
step:1550/1770 train_time:331380ms step_avg:213.79ms
step:1551/1770 train_time:331498ms step_avg:213.73ms
step:1552/1770 train_time:331615ms step_avg:213.67ms
step:1553/1770 train_time:331732ms step_avg:213.61ms
step:1554/1770 train_time:331848ms step_avg:213.54ms
step:1555/1770 train_time:331965ms step_avg:213.48ms
step:1556/1770 train_time:332082ms step_avg:213.42ms
step:1557/1770 train_time:332199ms step_avg:213.36ms
step:1558/1770 train_time:332316ms step_avg:213.30ms
step:1559/1770 train_time:332433ms step_avg:213.23ms
step:1560/1770 train_time:332550ms step_avg:213.17ms
step:1561/1770 train_time:332667ms step_avg:213.11ms
step:1562/1770 train_time:332784ms step_avg:213.05ms
step:1563/1770 train_time:332901ms step_avg:212.99ms
step:1564/1770 train_time:333018ms step_avg:212.93ms
step:1565/1770 train_time:333134ms step_avg:212.87ms
step:1566/1770 train_time:333250ms step_avg:212.80ms
step:1567/1770 train_time:333367ms step_avg:212.74ms
step:1568/1770 train_time:333485ms step_avg:212.68ms
step:1569/1770 train_time:333602ms step_avg:212.62ms
step:1570/1770 train_time:333718ms step_avg:212.56ms
step:1571/1770 train_time:333834ms step_avg:212.50ms
step:1572/1770 train_time:333951ms step_avg:212.44ms
step:1573/1770 train_time:334067ms step_avg:212.38ms
step:1574/1770 train_time:334185ms step_avg:212.32ms
step:1575/1770 train_time:334302ms step_avg:212.26ms
step:1576/1770 train_time:334419ms step_avg:212.19ms
step:1577/1770 train_time:334537ms step_avg:212.14ms
step:1578/1770 train_time:334656ms step_avg:212.08ms
step:1579/1770 train_time:334779ms step_avg:212.02ms
step:1580/1770 train_time:334898ms step_avg:211.96ms
step:1581/1770 train_time:335014ms step_avg:211.90ms
step:1582/1770 train_time:335132ms step_avg:211.84ms
step:1583/1770 train_time:335248ms step_avg:211.78ms
step:1584/1770 train_time:335366ms step_avg:211.72ms
step:1585/1770 train_time:335483ms step_avg:211.66ms
step:1586/1770 train_time:335601ms step_avg:211.60ms
step:1587/1770 train_time:335718ms step_avg:211.54ms
step:1588/1770 train_time:335837ms step_avg:211.48ms
step:1589/1770 train_time:335955ms step_avg:211.43ms
step:1590/1770 train_time:336073ms step_avg:211.37ms
step:1591/1770 train_time:336188ms step_avg:211.31ms
step:1592/1770 train_time:336305ms step_avg:211.25ms
step:1593/1770 train_time:336424ms step_avg:211.19ms
step:1594/1770 train_time:336542ms step_avg:211.13ms
step:1595/1770 train_time:336660ms step_avg:211.07ms
step:1596/1770 train_time:336777ms step_avg:211.01ms
step:1597/1770 train_time:336893ms step_avg:210.95ms
step:1598/1770 train_time:337012ms step_avg:210.90ms
step:1599/1770 train_time:337130ms step_avg:210.84ms
step:1600/1770 train_time:337247ms step_avg:210.78ms
step:1601/1770 train_time:337363ms step_avg:210.72ms
step:1602/1770 train_time:337481ms step_avg:210.66ms
step:1603/1770 train_time:337597ms step_avg:210.60ms
step:1604/1770 train_time:337714ms step_avg:210.54ms
step:1605/1770 train_time:337831ms step_avg:210.49ms
step:1606/1770 train_time:337949ms step_avg:210.43ms
step:1607/1770 train_time:338066ms step_avg:210.37ms
step:1608/1770 train_time:338183ms step_avg:210.31ms
step:1609/1770 train_time:338302ms step_avg:210.26ms
step:1610/1770 train_time:338419ms step_avg:210.20ms
step:1611/1770 train_time:338536ms step_avg:210.14ms
step:1612/1770 train_time:338653ms step_avg:210.08ms
step:1613/1770 train_time:338772ms step_avg:210.03ms
step:1614/1770 train_time:338888ms step_avg:209.97ms
step:1615/1770 train_time:339005ms step_avg:209.91ms
step:1616/1770 train_time:339123ms step_avg:209.85ms
step:1617/1770 train_time:339242ms step_avg:209.80ms
step:1618/1770 train_time:339361ms step_avg:209.74ms
step:1619/1770 train_time:339477ms step_avg:209.68ms
step:1620/1770 train_time:339594ms step_avg:209.63ms
step:1621/1770 train_time:339713ms step_avg:209.57ms
step:1622/1770 train_time:339833ms step_avg:209.51ms
step:1623/1770 train_time:339950ms step_avg:209.46ms
step:1624/1770 train_time:340068ms step_avg:209.40ms
step:1625/1770 train_time:340187ms step_avg:209.35ms
step:1625/1770 val_loss:4.7118 train_time:340191ms step_avg:209.35ms
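One detail worth making explicit about these lines: step_avg is simply the cumulative train_time divided by the step index, which you can verify directly from any line in the log. A one-liner reproducing the arithmetic for the val line just above (values taken from the log itself):

# step_avg as reported in the log: cumulative train time over step count.
train_time_ms, step = 340_191, 1_625  # from the step 1625 val line above
step_avg_ms = train_time_ms / step
print(f"step_avg:{step_avg_ms:.2f}ms")  # -> step_avg:209.35ms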
step:1626/1770 train_time:340307ms step_avg:209.29ms
step:1627/1770 train_time:340424ms step_avg:209.23ms
step:1628/1770 train_time:340542ms step_avg:209.18ms
step:1629/1770 train_time:340658ms step_avg:209.12ms
step:1630/1770 train_time:340777ms step_avg:209.07ms
step:1631/1770 train_time:340895ms step_avg:209.01ms
step:1632/1770 train_time:341012ms step_avg:208.95ms
step:1633/1770 train_time:341129ms step_avg:208.90ms
step:1634/1770 train_time:341245ms step_avg:208.84ms
step:1635/1770 train_time:341363ms step_avg:208.78ms
step:1636/1770 train_time:341480ms step_avg:208.73ms
step:1637/1770 train_time:341597ms step_avg:208.67ms
step:1638/1770 train_time:341714ms step_avg:208.62ms
step:1639/1770 train_time:341832ms step_avg:208.56ms
step:1640/1770 train_time:341949ms step_avg:208.51ms
step:1641/1770 train_time:342067ms step_avg:208.45ms
step:1642/1770 train_time:342184ms step_avg:208.39ms
step:1643/1770 train_time:342300ms step_avg:208.34ms
step:1644/1770 train_time:342416ms step_avg:208.28ms
step:1645/1770 train_time:342534ms step_avg:208.23ms
step:1646/1770 train_time:342653ms step_avg:208.17ms
step:1647/1770 train_time:342770ms step_avg:208.12ms
step:1648/1770 train_time:342887ms step_avg:208.06ms
step:1649/1770 train_time:343003ms step_avg:208.01ms
GALA update on step 1650: GalaAdam LRs=[0.240080, 14.639900, 0.000144], GalaMuon LRs=[0.000246]
step:1650/1770 train_time:343256ms step_avg:208.03ms
step:1651/1770 train_time:343370ms step_avg:207.98ms
step:1652/1770 train_time:343488ms step_avg:207.92ms
step:1653/1770 train_time:343607ms step_avg:207.87ms
step:1654/1770 train_time:343723ms step_avg:207.81ms
step:1655/1770 train_time:343839ms step_avg:207.76ms
step:1656/1770 train_time:343955ms step_avg:207.70ms
step:1657/1770 train_time:344073ms step_avg:207.65ms
step:1658/1770 train_time:344191ms step_avg:207.59ms
step:1659/1770 train_time:344308ms step_avg:207.54ms
step:1660/1770 train_time:344427ms step_avg:207.49ms
step:1661/1770 train_time:344546ms step_avg:207.43ms
step:1662/1770 train_time:344665ms step_avg:207.38ms
step:1663/1770 train_time:344780ms step_avg:207.32ms
step:1664/1770 train_time:344898ms step_avg:207.27ms
step:1665/1770 train_time:345015ms step_avg:207.22ms
step:1666/1770 train_time:345133ms step_avg:207.16ms
step:1667/1770 train_time:345250ms step_avg:207.11ms
step:1668/1770 train_time:345368ms step_avg:207.05ms
step:1669/1770 train_time:345488ms step_avg:207.00ms
step:1670/1770 train_time:345604ms step_avg:206.95ms
step:1671/1770 train_time:345720ms step_avg:206.89ms
step:1672/1770 train_time:345836ms step_avg:206.84ms
step:1673/1770 train_time:345953ms step_avg:206.79ms
step:1674/1770 train_time:346068ms step_avg:206.73ms
step:1675/1770 train_time:346187ms step_avg:206.68ms
step:1676/1770 train_time:346304ms step_avg:206.63ms
step:1677/1770 train_time:346421ms step_avg:206.57ms
step:1678/1770 train_time:346538ms step_avg:206.52ms
step:1679/1770 train_time:346656ms step_avg:206.47ms
step:1680/1770 train_time:346772ms step_avg:206.41ms
step:1681/1770 train_time:346889ms step_avg:206.36ms
step:1682/1770 train_time:347007ms step_avg:206.31ms
step:1683/1770 train_time:347125ms step_avg:206.25ms
step:1684/1770 train_time:347241ms step_avg:206.20ms
step:1685/1770 train_time:347358ms step_avg:206.15ms
step:1686/1770 train_time:347475ms step_avg:206.09ms
step:1687/1770 train_time:347593ms step_avg:206.04ms
step:1688/1770 train_time:347709ms step_avg:205.99ms
step:1689/1770 train_time:347827ms step_avg:205.94ms
step:1690/1770 train_time:347943ms step_avg:205.88ms
step:1691/1770 train_time:348060ms step_avg:205.83ms
step:1692/1770 train_time:348177ms step_avg:205.78ms
step:1693/1770 train_time:348294ms step_avg:205.73ms
step:1694/1770 train_time:348412ms step_avg:205.67ms
step:1695/1770 train_time:348528ms step_avg:205.62ms
step:1696/1770 train_time:348650ms step_avg:205.57ms
step:1697/1770 train_time:348764ms step_avg:205.52ms
step:1698/1770 train_time:348881ms step_avg:205.47ms
step:1699/1770 train_time:348998ms step_avg:205.41ms
step:1700/1770 train_time:349115ms step_avg:205.36ms
step:1701/1770 train_time:349232ms step_avg:205.31ms
step:1702/1770 train_time:349350ms step_avg:205.26ms
step:1703/1770 train_time:349467ms step_avg:205.21ms
step:1704/1770 train_time:349585ms step_avg:205.16ms
step:1705/1770 train_time:349702ms step_avg:205.10ms
step:1706/1770 train_time:349820ms step_avg:205.05ms
step:1707/1770 train_time:349940ms step_avg:205.00ms
step:1708/1770 train_time:350058ms step_avg:204.95ms
step:1709/1770 train_time:350176ms step_avg:204.90ms
step:1710/1770 train_time:350295ms step_avg:204.85ms
step:1711/1770 train_time:350413ms step_avg:204.80ms
step:1712/1770 train_time:350530ms step_avg:204.75ms
step:1713/1770 train_time:350649ms step_avg:204.70ms
step:1714/1770 train_time:350767ms step_avg:204.65ms
step:1715/1770 train_time:350886ms step_avg:204.60ms
step:1716/1770 train_time:351005ms step_avg:204.55ms
step:1717/1770 train_time:351124ms step_avg:204.50ms
step:1718/1770 train_time:351242ms step_avg:204.45ms
step:1719/1770 train_time:351358ms step_avg:204.40ms
step:1720/1770 train_time:351478ms step_avg:204.35ms
step:1721/1770 train_time:351594ms step_avg:204.30ms
step:1722/1770 train_time:351712ms step_avg:204.25ms
step:1723/1770 train_time:351830ms step_avg:204.20ms
step:1724/1770 train_time:351948ms step_avg:204.15ms
step:1725/1770 train_time:352066ms step_avg:204.10ms
step:1726/1770 train_time:352183ms step_avg:204.05ms
step:1727/1770 train_time:352300ms step_avg:204.00ms
step:1728/1770 train_time:352418ms step_avg:203.95ms
step:1729/1770 train_time:352536ms step_avg:203.90ms
step:1730/1770 train_time:352653ms step_avg:203.85ms
step:1731/1770 train_time:352771ms step_avg:203.80ms
step:1732/1770 train_time:352889ms step_avg:203.75ms
step:1733/1770 train_time:353007ms step_avg:203.70ms
step:1734/1770 train_time:353124ms step_avg:203.65ms
step:1735/1770 train_time:353243ms step_avg:203.60ms
step:1736/1770 train_time:353361ms step_avg:203.55ms
step:1737/1770 train_time:353479ms step_avg:203.50ms
step:1738/1770 train_time:353598ms step_avg:203.45ms
step:1739/1770 train_time:353715ms step_avg:203.40ms
step:1740/1770 train_time:353833ms step_avg:203.35ms
step:1741/1770 train_time:353951ms step_avg:203.30ms
step:1742/1770 train_time:354070ms step_avg:203.25ms
step:1743/1770 train_time:354188ms step_avg:203.21ms
step:1744/1770 train_time:354306ms step_avg:203.16ms
step:1745/1770 train_time:354424ms step_avg:203.11ms
step:1746/1770 train_time:354541ms step_avg:203.06ms
step:1747/1770 train_time:354658ms step_avg:203.01ms
step:1748/1770 train_time:354775ms step_avg:202.96ms
step:1749/1770 train_time:354895ms step_avg:202.91ms
step:1750/1770 train_time:355014ms step_avg:202.86ms
step:1750/1770 val_loss:4.6778 train_time:355017ms step_avg:202.87ms
step:1751/1770 train_time:355133ms step_avg:202.82ms
step:1752/1770 train_time:355252ms step_avg:202.77ms
step:1753/1770 train_time:355370ms step_avg:202.72ms
step:1754/1770 train_time:355487ms step_avg:202.67ms
step:1755/1770 train_time:355605ms step_avg:202.62ms
step:1756/1770 train_time:355723ms step_avg:202.58ms
step:1757/1770 train_time:355841ms step_avg:202.53ms
step:1758/1770 train_time:355959ms step_avg:202.48ms
step:1759/1770 train_time:356077ms step_avg:202.43ms
step:1760/1770 train_time:356193ms step_avg:202.38ms
step:1761/1770 train_time:356312ms step_avg:202.33ms
step:1762/1770 train_time:356430ms step_avg:202.29ms
step:1763/1770 train_time:356547ms step_avg:202.24ms
step:1764/1770 train_time:356666ms step_avg:202.19ms
step:1765/1770 train_time:356787ms step_avg:202.15ms
step:1766/1770 train_time:356910ms step_avg:202.10ms
step:1767/1770 train_time:357025ms step_avg:202.05ms
step:1768/1770 train_time:357144ms step_avg:202.00ms
step:1769/1770 train_time:357263ms step_avg:201.96ms
step:1770/1770 train_time:357381ms step_avg:201.91ms
step:1770/1770 val_loss:4.6760 train_time:357384ms step_avg:201.91ms
peak memory allocated: 32059 MiB reserved: 45992 MiB
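The summary line above reports the CUDA caching allocator's high-water marks. A minimal sketch of how such a line can be produced: torch.cuda.max_memory_allocated and torch.cuda.max_memory_reserved are real PyTorch APIs, but the surrounding print is an assumption about how this run formats its final report, not the gist's verbatim code.

import torch

# Hedged sketch: report allocator peaks in MiB at the end of training,
# in the same shape as the "peak memory allocated: ... reserved: ..." line.
peak_alloc_mib = torch.cuda.max_memory_allocated() // (1024 * 1024)
peak_reserved_mib = torch.cuda.max_memory_reserved() // (1024 * 1024)
print(f"peak memory allocated: {peak_alloc_mib} MiB reserved: {peak_reserved_mib} MiB")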