Test Different LoRA Adapter Saved Sizes
#!/usr/bin/env python3
"""
Script to test the size of saved models with different LoRA configurations:
1. Normal LoRA (attention layers only)
2. LoRA with embedding space adapter (train old embeddings)
3. LoRA with trainable_token_indices (specific token indices)
4. LoRA with new tokens added (train only new tokens)
5. LoRA with new tokens added (train entire embedding layer)
6. LoRA with new tokens added (train only new tokens + ensure_weight_tying=True)
Requirements:
pip install transformers peft torch
"""
import os
import shutil
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError:
    raise ImportError("PEFT library is required. Install it with: pip install peft")

def get_directory_size(directory_path):
    """Calculate the total size of a directory in bytes."""
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(directory_path):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            if os.path.exists(filepath):
                total_size += os.path.getsize(filepath)
    return total_size

def format_size(size_bytes):
    """Format bytes to human-readable format."""
    for unit in ['B', 'KB', 'MB', 'GB']:
        if size_bytes < 1024.0:
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.2f} TB"

def print_model_info(model, tokenizer, save_path):
    """Print information about the model and its saved size."""
    print("\n" + "="*60)
    print(f"Model saved at: {save_path}")
    print("="*60)
    # Model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\n📊 Model Statistics:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Vocabulary size: {len(tokenizer):,}")
    # File sizes
    if os.path.exists(save_path):
        dir_size = get_directory_size(save_path)
        print("\n💾 Saved Model Size:")
        print(f" Total size: {format_size(dir_size)}")
        # List individual files
        print("\n📁 Files in saved directory:")
        for root, dirs, files in os.walk(save_path):
            for file in files:
                file_path = os.path.join(root, file)
                file_size = os.path.getsize(file_path)
                rel_path = os.path.relpath(file_path, save_path)
                print(f" {rel_path}: {format_size(file_size)}")
    # Print trainable parameters breakdown
    print("\n🔧 Trainable Parameters Breakdown:")
    trainable_count = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_count += param.numel()
            print(f" {name}: {param.shape} ({param.numel():,} params)")
    print("\n" + "="*60)

def setup_normal_lora(model, tokenizer):
    """Setup normal LoRA on attention layers only."""
    print("\n" + "="*60)
    print("Setting up NORMAL LoRA (attention layers only)")
    print("="*60)
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Standard attention layers
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, lora_config

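# Note: for a PEFT-wrapped model, save_pretrained() typically writes only the
# adapter checkpoint (adapter_config.json plus adapter_model.safetensors), not
# the base weights, so the "normal LoRA" directory should stay in the MB range.
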
def setup_embedding_lora(model, tokenizer):
    """Setup LoRA with embedding space adapter using modules_to_save."""
    print("\n" + "="*60)
    print("Setting up LoRA with EMBEDDING SPACE ADAPTER (train old embeddings)")
    print("="*60)
    # Get the original vocabulary size to identify new tokens
    original_vocab_size = len(tokenizer)
    # Option 1: Use modules_to_save to train the embedding layer
    # This trains the entire embedding layer
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Still target attention layers
        modules_to_save=["embed_tokens"],  # Also train the embedding layer
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    print(f" modules_to_save: {lora_config.modules_to_save}")
    print(f" Original vocab size: {original_vocab_size:,}")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, lora_config

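# A minimal size-estimate sketch (assumes a config exposing `hidden_size` and
# bf16 weights; the helper name is ours, not part of PEFT): modules_to_save
# stores a full trainable copy of embed_tokens inside the adapter checkpoint,
# so the extra payload should be roughly vocab_size * hidden_size * bytes_per_param.
def estimate_full_embedding_overhead(model, tokenizer, bytes_per_param=2):
    """Approximate the extra adapter bytes from saving a full embed_tokens copy."""
    hidden_size = model.config.hidden_size
    return len(tokenizer) * hidden_size * bytes_per_param
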
def setup_embedding_lora_trainable_indices(model, tokenizer):
    """Setup LoRA with embedding space adapter using trainable_token_indices."""
    print("\n" + "="*60)
    print("Setting up LoRA with EMBEDDING SPACE ADAPTER (trainable_token_indices)")
    print("="*60)
    # Get the vocabulary size
    vocab_size = len(tokenizer)
    # Option 2: Use trainable_token_indices for specific tokens
    # Train only a subset of tokens (e.g., last 10% or a specific range)
    # This is useful when you've added new tokens and only want to train those
    # For demonstration, we'll train the last 10% of tokens
    trainable_indices_start = int(vocab_size * 0.9)  # Last 10% of tokens
    trainable_indices = list(range(trainable_indices_start, vocab_size))  # inclusive of the last token, matching the printout below
    print(f"\n Vocabulary size: {vocab_size:,}")
    print(f" Training tokens from index {trainable_indices_start} to {vocab_size-1}")
    print(f" Number of trainable token indices: {len(trainable_indices):,}")
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Still target attention layers
        trainable_token_indices={"embed_tokens": trainable_indices},  # Train specific token indices
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    print(f" trainable_token_indices: {len(trainable_indices):,} tokens")
    print(f" (indices {trainable_indices_start} to {vocab_size-1})")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Store trainable_indices for later use in summary
    lora_config.trainable_indices_list = trainable_indices
    return model, lora_config

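# Note: in recent PEFT releases, trainable_token_indices adds a small
# "trainable tokens" adapter that stores rows only for the listed indices, so
# the saved checkpoint should scale with len(trainable_indices) rather than
# with the full vocabulary; the comparison in main() is meant to confirm this.
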
def add_new_tokens_to_model(model, tokenizer):
    """Helper function to add new tokens to tokenizer and resize model embeddings."""
    # Store original vocabulary size
    original_vocab_size = len(tokenizer)
    # Add new tokens to tokenizer
    # For demonstration, add some example tokens (e.g., special tokens for an audio codec)
    new_tokens = [
        "<audio_start>", "<audio_end>", "<layer_sep>",
        "<snac_pad>", "<snac_token_1>", "<snac_token_2>",
        "<snac_token_3>", "<snac_token_4>", "<snac_token_5>",
        "<snac_token_6>", "<snac_token_7>", "<snac_token_8>",
        "<snac_token_9>", "<snac_token_10>", "<snac_token_11>",
        "<snac_token_12>", "<snac_token_13>", "<snac_token_14>",
        "<snac_token_15>", "<snac_token_16>", "<snac_token_17>",
        "<snac_token_18>", "<snac_token_19>", "<snac_token_20>",
    ]
    # Add tokens to tokenizer
    num_added = tokenizer.add_tokens(new_tokens)
    # Resize model embeddings to accommodate new tokens
    model.resize_token_embeddings(len(tokenizer))
    # Get the indices of the newly added tokens
    new_token_indices = list(range(original_vocab_size, len(tokenizer)))
    return original_vocab_size, num_added, new_token_indices

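# Caveat (hedged): for checkpoints with tied input/output embeddings,
# resize_token_embeddings() grows the LM head together with embed_tokens, and
# recent transformers versions initialize the new rows from the mean of the
# existing embeddings by default; verify both points for the model you use.
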
def setup_new_tokens_lora(model, tokenizer):
    """Setup LoRA with new tokens added to tokenizer and model, train only new tokens."""
    print("\n" + "="*60)
    print("Setting up LoRA with NEW TOKENS (train only new tokens)")
    print("="*60)
    # Add new tokens using helper function
    original_vocab_size, num_added, new_token_indices = add_new_tokens_to_model(model, tokenizer)
    print(f"\n Original vocabulary size: {original_vocab_size:,}")
    print(f" Added {num_added} new tokens to tokenizer")
    print(f" New vocabulary size: {len(tokenizer):,}")
    print(f" Resized model embeddings to {len(tokenizer):,} tokens")
    print(f" New token indices: {original_vocab_size} to {len(tokenizer)-1}")
    print(f" Number of new token indices: {len(new_token_indices):,}")
    # Setup LoRA with trainable_token_indices for only the new tokens
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Still target attention layers
        trainable_token_indices={"embed_tokens": new_token_indices},  # Train only new tokens
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    print(f" trainable_token_indices: {len(new_token_indices):,} new tokens")
    print(f" (indices {original_vocab_size} to {len(tokenizer)-1})")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Store new_token_indices for later use in summary
    lora_config.new_token_indices_list = new_token_indices
    lora_config.original_vocab_size = original_vocab_size
    lora_config.num_new_tokens = num_added
    return model, lora_config

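# Expectation (to be confirmed by the size comparison): since only the 24 new
# embedding rows are stored, this adapter should be only marginally larger than
# the normal LoRA checkpoint.
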
def setup_new_tokens_full_embedding_lora(model, tokenizer):
    """Setup LoRA with new tokens added to tokenizer and model, train entire embedding layer."""
    print("\n" + "="*60)
    print("Setting up LoRA with NEW TOKENS (train entire embedding layer)")
    print("="*60)
    # Add new tokens using helper function
    original_vocab_size, num_added, new_token_indices = add_new_tokens_to_model(model, tokenizer)
    print(f"\n Original vocabulary size: {original_vocab_size:,}")
    print(f" Added {num_added} new tokens to tokenizer")
    print(f" New vocabulary size: {len(tokenizer):,}")
    print(f" Resized model embeddings to {len(tokenizer):,} tokens")
    print(f" New token indices: {original_vocab_size} to {len(tokenizer)-1}")
    print(f" Number of new token indices: {len(new_token_indices):,}")
    # Setup LoRA with modules_to_save to train the entire embedding layer
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Still target attention layers
        modules_to_save=["embed_tokens"],  # Train entire embedding layer (including new tokens)
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    print(f" modules_to_save: {lora_config.modules_to_save}")
    print(f" Training entire embedding layer ({len(tokenizer):,} tokens total)")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Store metadata for later use in summary
    lora_config.new_token_indices_list = new_token_indices
    lora_config.original_vocab_size = original_vocab_size
    lora_config.num_new_tokens = num_added
    return model, lora_config

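# Expectation: this variant serializes the entire resized embedding matrix via
# modules_to_save, so it should be the largest checkpoint in the comparison.
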
def setup_embedding_lora_with_weight_tying(model, tokenizer):
    """Setup LoRA with new tokens added to tokenizer and model, train only new tokens with ensure_weight_tying=True."""
    print("\n" + "="*60)
    print("Setting up LoRA with NEW TOKENS (train only new tokens + ensure_weight_tying)")
    print("="*60)
    # Add new tokens using helper function
    original_vocab_size, num_added, new_token_indices = add_new_tokens_to_model(model, tokenizer)
    print(f"\n Original vocabulary size: {original_vocab_size:,}")
    print(f" Added {num_added} new tokens to tokenizer")
    print(f" New vocabulary size: {len(tokenizer):,}")
    print(f" Resized model embeddings to {len(tokenizer):,} tokens")
    print(f" New token indices: {original_vocab_size} to {len(tokenizer)-1}")
    print(f" Number of new token indices: {len(new_token_indices):,}")
    # Setup LoRA with trainable_token_indices for only the new tokens
    # Add ensure_weight_tying=True to keep weight tying consistent and mark the embedding layer as trainable
    lora_config = LoraConfig(
        r=16,  # LoRA attention dimension
        lora_alpha=32,  # Alpha parameter for LoRA scaling
        lora_dropout=0.05,  # Dropout probability for LoRA layers
        bias="none",  # Bias type for LoRA
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # Still target attention layers
        trainable_token_indices={"embed_tokens": new_token_indices},  # Train only new tokens
        ensure_weight_tying=True,  # Keep weight tying consistent and mark the embedding layer as trainable
    )
    print("\nLoRA Configuration:")
    print(f" r (rank): {lora_config.r}")
    print(f" lora_alpha: {lora_config.lora_alpha}")
    print(f" lora_dropout: {lora_config.lora_dropout}")
    print(f" target_modules: {lora_config.target_modules}")
    print(f" trainable_token_indices: {len(new_token_indices):,} new tokens")
    print(f" (indices {original_vocab_size} to {len(tokenizer)-1})")
    print(f" ensure_weight_tying: {lora_config.ensure_weight_tying}")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Store new_token_indices for later use in summary
    lora_config.new_token_indices_list = new_token_indices
    lora_config.original_vocab_size = original_vocab_size
    lora_config.num_new_tokens = num_added
    return model, lora_config

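# Note: ensure_weight_tying only matters for checkpoints whose input and output
# embeddings are tied; the intent here (per the config comment above) is to keep
# the new-token updates consistent across both, and the size comparison in
# main() shows whether this changes what gets serialized.
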
def main():
    """Main function to test model sizes."""
    print("="*60)
    print("MODEL SIZE TESTING SCRIPT")
    print("="*60)
    # Configuration
    # Try different gemma-3 variants, fallback to gemma-2
    model_candidates = [
        "google/gemma-3-270m",
        # "google/gemma-3-2b",
        # "google/gemma-2-2b",
        # "google/gemma-2-9b",
    ]
    base_output_dir = "./test_model_sizes"
    # Try to load gemma-3 (or fallback to gemma-2)
    model_name_used = None
    model = None
    tokenizer = None
    for candidate in model_candidates:
        try:
            print(f"\n🔍 Attempting to load: {candidate}")
            model = AutoModelForCausalLM.from_pretrained(
                candidate,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="cpu"
            )
            tokenizer = AutoTokenizer.from_pretrained(
                candidate,
                trust_remote_code=True
            )
            model_name_used = candidate
            print(f"✅ Successfully loaded: {candidate}")
            break
        except Exception as e:
            print(f"⚠️ Could not load {candidate}: {e}")
            continue
    if model is None or tokenizer is None:
        raise RuntimeError(f"Failed to load any model from candidates: {model_candidates}")
    # Add padding token if not present
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print(f"\n✅ Loaded model: {model_name_used}")
    print(f" Vocabulary size: {len(tokenizer):,}")
    # Clean and remove existing output directory if it exists
    if os.path.exists(base_output_dir):
        print(f"\n🧹 Cleaning existing output directory: {base_output_dir}")
        shutil.rmtree(base_output_dir)
        print("✅ Removed existing directory")
    # Create output directory
    os.makedirs(base_output_dir, exist_ok=True)

    # Test 1: Normal LoRA
    print("\n" + "="*60)
    print("TEST 1: NORMAL LoRA (Attention Layers Only)")
    print("="*60)
    # Load fresh model for test 1
    model1 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer1 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer1.pad_token is None:
        tokenizer1.pad_token = tokenizer1.eos_token
        tokenizer1.pad_token_id = tokenizer1.eos_token_id
    model1, lora_config1 = setup_normal_lora(model1, tokenizer1)
    save_path1 = os.path.join(base_output_dir, "normal_lora")
    print(f"\n💾 Saving normal LoRA model to: {save_path1}")
    model1.save_pretrained(save_path1)
    tokenizer1.save_pretrained(save_path1)
    print_model_info(model1, tokenizer1, save_path1)

    # Test 2: LoRA with Embedding Adapter
    print("\n" + "="*60)
    print("TEST 2: LoRA with EMBEDDING SPACE ADAPTER")
    print("="*60)
    # Load fresh model for test 2
    model2 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer2 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer2.pad_token is None:
        tokenizer2.pad_token = tokenizer2.eos_token
        tokenizer2.pad_token_id = tokenizer2.eos_token_id
    model2, lora_config2 = setup_embedding_lora(model2, tokenizer2)
    save_path2 = os.path.join(base_output_dir, "embedding_lora")
    print(f"\n💾 Saving embedding LoRA model to: {save_path2}")
    model2.save_pretrained(save_path2)
    tokenizer2.save_pretrained(save_path2)
    print_model_info(model2, tokenizer2, save_path2)

    # Test 3: LoRA with Trainable Token Indices
    print("\n" + "="*60)
    print("TEST 3: LoRA with TRAINABLE TOKEN INDICES")
    print("="*60)
    # Load fresh model for test 3
    model3 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer3 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer3.pad_token is None:
        tokenizer3.pad_token = tokenizer3.eos_token
        tokenizer3.pad_token_id = tokenizer3.eos_token_id
    model3, lora_config3 = setup_embedding_lora_trainable_indices(model3, tokenizer3)
    save_path3 = os.path.join(base_output_dir, "trainable_indices_lora")
    print(f"\n💾 Saving trainable indices LoRA model to: {save_path3}")
    model3.save_pretrained(save_path3)
    tokenizer3.save_pretrained(save_path3)
    print_model_info(model3, tokenizer3, save_path3)

    # Test 4: LoRA with New Tokens Added
    print("\n" + "="*60)
    print("TEST 4: LoRA with NEW TOKENS ADDED")
    print("="*60)
    # Load fresh model for test 4
    model4 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer4 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer4.pad_token is None:
        tokenizer4.pad_token = tokenizer4.eos_token
        tokenizer4.pad_token_id = tokenizer4.eos_token_id
    model4, lora_config4 = setup_new_tokens_lora(model4, tokenizer4)
    save_path4 = os.path.join(base_output_dir, "new_tokens_lora")
    print(f"\n💾 Saving new tokens LoRA model to: {save_path4}")
    model4.save_pretrained(save_path4)
    tokenizer4.save_pretrained(save_path4)
    print_model_info(model4, tokenizer4, save_path4)

    # Test 5: LoRA with New Tokens Added (Full Embedding)
    print("\n" + "="*60)
    print("TEST 5: LoRA with NEW TOKENS ADDED (Full Embedding)")
    print("="*60)
    # Load fresh model for test 5
    model5 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer5 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer5.pad_token is None:
        tokenizer5.pad_token = tokenizer5.eos_token
        tokenizer5.pad_token_id = tokenizer5.eos_token_id
    model5, lora_config5 = setup_new_tokens_full_embedding_lora(model5, tokenizer5)
    save_path5 = os.path.join(base_output_dir, "new_tokens_full_embedding_lora")
    print(f"\n💾 Saving new tokens full embedding LoRA model to: {save_path5}")
    model5.save_pretrained(save_path5)
    tokenizer5.save_pretrained(save_path5)
    print_model_info(model5, tokenizer5, save_path5)

    # Test 6: LoRA with New Tokens (Weight Tying)
    print("\n" + "="*60)
    print("TEST 6: LoRA with NEW TOKENS (Weight Tying)")
    print("="*60)
    # Load fresh model for test 6
    model6 = AutoModelForCausalLM.from_pretrained(
        model_name_used,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cpu"
    )
    tokenizer6 = AutoTokenizer.from_pretrained(
        model_name_used,
        trust_remote_code=True
    )
    if tokenizer6.pad_token is None:
        tokenizer6.pad_token = tokenizer6.eos_token
        tokenizer6.pad_token_id = tokenizer6.eos_token_id
    model6, lora_config6 = setup_embedding_lora_with_weight_tying(model6, tokenizer6)
    save_path6 = os.path.join(base_output_dir, "new_tokens_lora_weight_tying")
    print(f"\n💾 Saving new tokens LoRA with weight tying model to: {save_path6}")
    model6.save_pretrained(save_path6)
    tokenizer6.save_pretrained(save_path6)
    print_model_info(model6, tokenizer6, save_path6)

    # Comparison
    print("\n" + "="*60)
    print("SIZE COMPARISON")
    print("="*60)
    size1 = get_directory_size(save_path1)
    size2 = get_directory_size(save_path2)
    size3 = get_directory_size(save_path3)
    size4 = get_directory_size(save_path4)
    size5 = get_directory_size(save_path5)
    size6 = get_directory_size(save_path6)
    print("\n📊 Model Size Comparison:")
    print(f" Normal LoRA: {format_size(size1)}")
    print(f" Embedding LoRA (train old embeddings): {format_size(size2)}")
    print(f" Embedding LoRA (trainable_indices): {format_size(size3)}")
    print(f" New Tokens LoRA (train new only): {format_size(size4)}")
    print(f" New Tokens LoRA (train old + new embedding): {format_size(size5)}")
    print(f" New Tokens LoRA (train new only + weight tying): {format_size(size6)}")
    sizes = [
        ("Normal LoRA", size1),
        ("Embedding LoRA (train old embeddings)", size2),
        ("Embedding LoRA (trainable_indices)", size3),
        ("New Tokens LoRA (train new only)", size4),
        ("New Tokens LoRA (train old + new embedding)", size5),
        ("New Tokens LoRA (train new only + weight tying)", size6)
    ]
    sizes_sorted = sorted(sizes, key=lambda x: x[1])
    print("\n📈 Size Ranking (smallest to largest):")
    for i, (name, size) in enumerate(sizes_sorted, 1):
        print(f" {i}. {name}: {format_size(size)}")
    if size2 > size1:
        print(f"\n Embedding LoRA (train old embeddings) is {size2/size1:.2f}x larger than Normal LoRA")
    if size3 > size1:
        print(f" Embedding LoRA (trainable_indices) is {size3/size1:.2f}x larger than Normal LoRA")
    if size4 > size1:
        print(f" New Tokens LoRA (train new only) is {size4/size1:.2f}x larger than Normal LoRA")
    if size5 > size1:
        print(f" New Tokens LoRA (train old + new embedding) is {size5/size1:.2f}x larger than Normal LoRA")
    if size6 > size1:
        print(f" New Tokens LoRA (train new only + weight tying) is {size6/size1:.2f}x larger than Normal LoRA")
    if size3 > size2:
        print(f" Embedding LoRA (trainable_indices) is {size3/size2:.2f}x larger than modules_to_save")
    elif size2 > size3:
        print(f" Embedding LoRA (train old embeddings) is {size2/size3:.2f}x larger than trainable_indices")
    if size4 > size3:
        print(f" New Tokens LoRA (train new only) is {size4/size3:.2f}x larger than trainable_indices")
    elif size3 > size4:
        print(f" Embedding LoRA (trainable_indices) is {size3/size4:.2f}x larger than New Tokens LoRA (train new only)")
    if size6 > size4:
        print(f" New Tokens LoRA (train new only + weight tying) is {size6/size4:.2f}x larger than New Tokens LoRA (train new only)")
    elif size4 > size6:
        print(f" New Tokens LoRA (train new only) is {size4/size6:.2f}x larger than New Tokens LoRA (train new only + weight tying)")
    if size5 > size4:
        print(f" New Tokens LoRA (train old + new embedding) is {size5/size4:.2f}x larger than New Tokens LoRA (train new only)")
    elif size4 > size5:
        print(f" New Tokens LoRA (train new only) is {size4/size5:.2f}x larger than New Tokens LoRA (train old + new embedding)")
    if size5 > size2:
        print(f" New Tokens LoRA (train old + new embedding) is {size5/size2:.2f}x larger than Embedding LoRA (train old embeddings)")
    elif size2 > size5:
        print(f" Embedding LoRA (train old embeddings) is {size2/size5:.2f}x larger than New Tokens LoRA (train old + new embedding)")
    if size6 > size5:
        print(f" New Tokens LoRA (train new only + weight tying) is {size6/size5:.2f}x larger than New Tokens LoRA (train old + new embedding)")
    elif size5 > size6:
        print(f" New Tokens LoRA (train old + new embedding) is {size5/size6:.2f}x larger than New Tokens LoRA (train new only + weight tying)")
    print("\n" + "="*60)
    print("TESTING COMPLETE")
    print("="*60)
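
# A minimal reload sketch (not called by main(); the helper name and flow are
# ours): adapters saved above are loaded back with PeftModel.from_pretrained.
# For the new-token variants, load the tokenizer from the adapter directory and
# resize the base model's embeddings to match before attaching the adapter.
def load_adapter_for_inference(base_model_name, adapter_path):
    """Reload a saved adapter on top of a freshly loaded base model."""
    from peft import PeftModel
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # No-op when the vocabulary was not extended; required for the new-token tests.
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(base_model, adapter_path)
    return model, tokenizer
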
if __name__ == "__main__":
    main()