# AuroraGPT-7B

Large Language Model trained at Argonne National Laboratory.

## Model Details
| Parameter | Value |
|---|---|
| activation | SwiGLU |
| activation_ckpt | False |
| dtype | bfloat16 |
| ffn_hidden_size | 11008 |
| hidden_dim | 4096 |
| n_layers | 32 |
| n_heads | 32 |
| n_kv_heads | 32 |
| norm | RMSNorm |
| seq_len | 4096 |
| ZeRO Stage | 1 |
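
These hyperparameters correspond to the Hugging Face `LlamaConfig` shipped with the checkpoint. As a minimal sketch (loading `argonne-private/AuroraGPT-7B` assumes access to the private repository), the values in the table can be cross-checked directly against the config:

```python
# Minimal sketch: inspect the LlamaConfig behind AuroraGPT-7B and compare it
# with the table above. Requires access to the private Hugging Face repo.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("argonne-private/AuroraGPT-7B")
print(config.hidden_size)              # hidden_dim      -> 4096
print(config.intermediate_size)        # ffn_hidden_size -> 11008
print(config.num_hidden_layers)        # n_layers        -> 32
print(config.num_attention_heads)      # n_heads         -> 32
print(config.max_position_embeddings)  # seq_len         -> 4096
print(config.torch_dtype)              # dtype           -> torch.bfloat16
```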
- Setup Environment: We use [`ezpz`](https://github.com/saforem2/ezpz)[^1] for setting up our environment:

  ```bash
  source <(curl 'https://raw.githubusercontent.com/saforem2/ezpz/refs/heads/main/src/ezpz/bin/utils.sh')
  ezpz_setup_env
  ```
- Install Dependencies[^2]:

  - `ezpz`:

    ```bash
    python3 -m pip install "git+https://github.com/saforem2/ezpz" --require-virtualenv
    ```

  - `{transformers,datasets}`:

    ```bash
    python3 -m pip install --upgrade transformers datasets
    ```
- Launch Training:

  ```bash
  MASTER_PORT=5432 \
  MASTER_ADDR=$(hostname) \
  WORLD_SIZE="${NGPUS}" \
  launch python3 src/agpt/hf_trainer.py \
      --dataset_name stanfordnlp/imdb \
      --model_name_or_path argonne-private/AuroraGPT-7B \
      --bf16 \
      --per_device_train_batch_size=1 \
      --block_size=8 \
      --gradient_checkpointing \
      --do_train \
      --deepspeed=zero_stage1_config.json \
      --output_dir=trainer_output-$(tstamp)
  ```
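
The launch command points `--deepspeed` at a `zero_stage1_config.json` file whose contents are not reproduced in this document. As a hypothetical sketch of what such a ZeRO stage-1 configuration could look like, an equivalent config can also be expressed as a Python dict and passed directly to `TrainingArguments(deepspeed=...)`; the `"auto"` values defer to the Trainer's own settings:

```python
# Hypothetical ZeRO stage-1 DeepSpeed config; the repository's actual
# zero_stage1_config.json may differ. "auto" lets the HF Trainer fill in
# values from its own TrainingArguments.
ds_config = {
    "zero_optimization": {"stage": 1},         # ZeRO stage 1: shard optimizer states only
    "bf16": {"enabled": "auto"},               # follow --bf16
    "train_micro_batch_size_per_gpu": "auto",  # follow --per_device_train_batch_size
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
}
# e.g. TrainingArguments(..., deepspeed=ds_config)
```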
Output:

```
#[aurora_nre_models_frameworks-2024.2.1_u1]
#[08:40:05 PM][x4110c1s5b0n0][/f/d/f/p/a/AuroraGPT][main][42s]
$ MASTER_ADDR=$(hostname) MASTER_PORT=5432 WORLD_SIZE="${NGPUS}" launch python3 src/agpt/hf_trainer.py --dataset_name "HuggingFaceTB/finemath" --dataset_config_name="finemath-4plus" --model_name_or_path argonne-private/AuroraGPT-7B --bf16 --per_device_train_batch_size=1 --block_size=256 --gradient_checkpointing --do_train --deepspeed=zero_stage1_config.json --output_dir=trainer_output-$(tstamp)
[2025-03-06 20:41:15][I][datasets/config:54:datasets] PyTorch version 2.3.1+cxx11.abi available.
[2025-03-06 20:41:15][I][datasets/config:112:datasets] TensorFlow version 2.15.1 available.
[2025-03-06 20:41:15][I][datasets/config:125:datasets] JAX version 0.5.0 available.
[2025-03-06 20:41:20][I][ezpz/dist:658] Caught MASTER_PORT=5432 from environment!
[2025-03-06 20:41:20][I][ezpz/dist:531] Using get_torch_device_type()='xpu' with backend='ccl'
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 2/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 7/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 9/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 1/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 6/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][10/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 8/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 3/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][11/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 5/11]
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 4/11]
[2025-03-06 20:41:20][I][ezpz/dist:845] Using device='xpu' with backend='DDP' + 'ccl' for distributed training.
[2025-03-06 20:41:20][I][ezpz/dist:895] ['x4110c1s5b0n0'][ 0/11]
[2025-03-06 20:41:20,942] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,942] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,943] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,943] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,943] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,943] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,944] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,944] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,945] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,945] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,945] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,945] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,945] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,945] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,946] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,946] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,946] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,946] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,953] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,953] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,956] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,956] [INFO] [comm.py:161:init_deepspeed_backend] Initialize ccl backend
[2025-03-06 20:41:20,956] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20,956] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-03-06 20:41:20][W][utils/_logger:68:__main__] Process rank: 0, device: xpu:0, n_gpu: 1, distributed training: True, 16-bits training: False
[2025-03-06 20:41:20][I][agpt/hf_trainer:324:__main__] Training/evaluation parameters TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=True,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=zero_stage1_config.json,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=no,
eval_use_gather_object=False,
evaluation_strategy=None,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=None,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_for_metrics=[],
include_inputs_for_metrics=False,
include_num_input_tokens_seen=False,
include_tokens_per_second=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=5e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=trainer_output-2025-03-06-204030/runs/Mar06_20-41-20_x4110c1s5b0n0,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=500,
logging_strategy=steps,
lr_scheduler_kwargs={},
lr_scheduler_type=linear,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_torch,
optim_args=None,
optim_target_modules=None,
output_dir=trainer_output-2025-03-06-204030,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=8,
per_device_train_batch_size=1,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=['tensorboard', 'wandb'],
restore_callback_states_from_checkpoint=False,
resume_from_checkpoint=None,
run_name=trainer_output-2025-03-06-204030,
save_on_each_node=False,
save_only_model=False,
save_safetensors=True,
save_steps=500,
save_strategy=steps,
save_total_limit=None,
seed=42,
skip_memory_metrics=True,
split_batches=None,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torch_empty_cache_steps=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_liger_kernel=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=0,
weight_decay=0.0,
)
Downloading data: 100%|██████████| 64/64 [02:59<00:00, 2.81s/files]
Generating train split: 100%|██████████| 6699493/6699493 [01:31<00:00, 73582.91 examples/s]
[INFO|tokenization_auto.py:730] 2025-03-06 20:45:56,194 >> Could not locate the tokenizer configuration file, will try to use the model config instead.
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file tokenizer.model from cache at /home/foremans/.cache/huggingface/hub/models--argonne-private--AuroraGPT-7B/snapshots/ab1a9c95f3f913bb58570598f54e99bce68c2608/tokenizer.model
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file tokenizer.json from cache at None
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file added_tokens.json from cache at None
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file special_tokens_map.json from cache at None
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file tokenizer_config.json from cache at None
[INFO|tokenization_utils_base.py:2050] 2025-03-06 20:45:56,484 >> loading file chat_template.jinja from cache at None
[INFO|configuration_utils.py:699] 2025-03-06 20:45:56,488 >> loading configuration file config.json from cache at /home/foremans/.cache/huggingface/hub/models--argonne-private--AuroraGPT-7B/snapshots/ab1a9c95f3f913bb58570598f54e99bce68c2608/config.json
[INFO|configuration_utils.py:771] 2025-03-06 20:45:56,489 >> Model config LlamaConfig {
"_name_or_path": "argonne-private/AuroraGPT-7B",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 4096,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": true,
"vocab_size": 32000
}
2025-03-06 17:10:28,340 - _logger.py - IPEX - INFO - Currently split master weight for xpu only support sgd
[WARNING|_logger.py:68] 2025-03-06 20:45:56,626 >> You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
[WARNING|_logger.py:68] 2025-03-06 20:45:56,766 >> You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
[INFO|modeling_utils.py:3982] 2025-03-06 20:45:58,354 >> loading weights file model.safetensors from cache at /home/foremans/.cache/huggingface/hub/models--argonne-private--AuroraGPT-7B/snapshots/ab1a9c95f3f913bb58570598f54e99bce68c2608/model.safetensors
[INFO|configuration_utils.py:1140] 2025-03-06 20:45:58,707 >> Generate config GenerationConfig {
"bos_token_id": 1,
"eos_token_id": 2
}
[INFO|modeling_utils.py:4970] 2025-03-06 20:48:33,844 >> All model checkpoint weights were used when initializing LlamaForCausalLM.
[INFO|modeling_utils.py:4978] 2025-03-06 20:48:33,844 >> All the weights of LlamaForCausalLM were initialized from the model checkpoint at argonne-private/AuroraGPT-7B.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[INFO|modeling_utils.py:4366] 2025-03-06 20:48:33,903 >> Generation config file not found, using a generation config created from the model config.
Running tokenizer on dataset:   3%|▎         | 190000/6364518 [02:05<1:04:23, 1598.37 examples/s]
```

We provide below a simple recipe for fine-tuning the base AuroraGPT-7B model on a single device[^3]:

```python
import ezpz
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

BATCH_SIZE = 2

# Set up torch (device + distributed state) via ezpz
rank = ezpz.setup_torch(backend='DDP')
torch.set_default_device(f'{ezpz.get_torch_device_type()}:{ezpz.get_local_rank()}')

dataset = load_dataset("HuggingFaceTB/finemath", "finemath-3plus")
dataset.set_format("torch", device="xpu")

model = AutoModelForCausalLM.from_pretrained(
    'argonne-private/AuroraGPT-7B',
    device_map={'': torch.get_default_device()},
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained('argonne-private/AuroraGPT-7B')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=1,  # NOTE: pads/truncates every example to a single token; increase for real fine-tuning
        return_tensors='pt',
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_datasets['train'].shuffle(seed=42),
    args=TrainingArguments(
        optim='adafactor',
        per_device_train_batch_size=BATCH_SIZE,
        output_dir=f'outputs/training-{ezpz.get_timestamp()}',
    ),
)
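
# The tokenized dataset above has no explicit `labels` column. As a sketch
# (not part of the original recipe), a standard way to obtain a causal-LM
# loss is to pass a collator that copies `input_ids` into `labels`, e.g.:
#
#     from transformers import DataCollatorForLanguageModeling
#     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
#     trainer = Trainer(model=model, args=..., train_dataset=..., data_collator=collator)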
trainer.train()

# Outputs:
# [2025-03-04 21:47:36][I][ezpz/dist:790] Running on a single xpu, not initializing torch.distributed!
# [2025-03-04 21:47:36][I][ezpz/dist:845] Using device='xpu' with backend='DDP' + 'ccl' for distributed training.
# [2025-03-04 21:47:36][I][ezpz/dist:895] ['x4517c0s4b0n0'][0/0]
#   0%|▏         | 255/75000 [03:48<18:37:40, 1.11it/s]
```

To reduce memory pressure, the same recipe can be run with gradient checkpointing enabled; the variant below also switches to the stanfordnlp/imdb dataset and a per-device batch size of 1:

```python
import ezpz
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

BATCH_SIZE = 1

rank = ezpz.setup_torch(backend='DDP')
torch.set_default_device(f'{ezpz.get_torch_device_type()}:{ezpz.get_local_rank()}')

dataset = load_dataset('stanfordnlp/imdb')
dataset.set_format("torch", device=ezpz.get_torch_device_type())

model = AutoModelForCausalLM.from_pretrained(
    'argonne-private/AuroraGPT-7B',
    device_map={'': torch.get_default_device()},
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained('argonne-private/AuroraGPT-7B')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=1,
        return_tensors='pt',
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets['train'].shuffle(seed=42)

training_args = TrainingArguments(
    optim='adamw_hf',
    per_device_train_batch_size=BATCH_SIZE,
    gradient_checkpointing=True,
    output_dir=f'outputs/training-{ezpz.get_timestamp()}',
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```

Finally, the model (base or fine-tuned) can be prompted for text generation:

```python
import ezpz
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def prompt_model(model, tokenizer, prompt, max_length: int = 64) -> str:
    with torch.autocast(
        device_type=ezpz.get_torch_device_type(),
        dtype=torch.bfloat16,
    ):
        return (
            tokenizer.batch_decode(
                model.generate(
                    **tokenizer(prompt, return_tensors='pt'),
                    max_length=max_length,
                ),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )[0]
        )


model = AutoModelForCausalLM.from_pretrained('argonne-private/AuroraGPT-7B')
tokenizer = AutoTokenizer.from_pretrained('argonne-private/AuroraGPT-7B')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(
    prompt_model(
        model=model,
        tokenizer=tokenizer,
        prompt="How many 'r's are in the word 'strawberry?'",
        max_length=64,
    )
)
# Outputs:
# How many 'r's are in the word'strawberry?'
# I'm not sure how many 'r's are in the word'strawberry' but I'm pretty sure it's 3.
```
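
The `prompt_model` helper above uses greedy decoding (the default for `model.generate`). As an illustrative sketch that reuses the `model` and `tokenizer` loaded in the previous block, sampling can be enabled through the standard `transformers` generation arguments; the parameter values here are arbitrary, not taken from the original:

```python
# Illustrative only: sampling-based generation with the model / tokenizer
# loaded above. do_sample, temperature, and top_p are standard
# transformers generate() arguments; the values are arbitrary.
outputs = model.generate(
    **tokenizer("How many 'r's are in the word 'strawberry?'", return_tensors='pt'),
    max_length=64,
    do_sample=True,   # sample instead of greedy decoding
    temperature=0.7,  # soften the next-token distribution
    top_p=0.95,       # nucleus sampling
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```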
## Footnotes

[^1]: You should always be working in a virtual environment. See [venv: Creation of virtual environments (Python 3 documentation)](https://docs.python.org/3/library/venv.html).

[^2]: The following instructions were tested and confirmed to be working on Aurora @ ALCF as of 2025-03-04.