"""
Tinker Financial Q&A Fine-Tuning with FinCoT Dataset
Uses chain-of-thought reasoning dataset for improved answer quality
Includes validation tracking, warmup, and proper checkpoint management
"""
import time
import numpy as np
from dotenv import load_dotenv
from datasets import load_dataset
# Load environment variables
load_dotenv()
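# Assumes the Tinker API key is supplied via the environment, e.g. a TINKER_API_KEY
# entry in the .env file loaded above (the exact variable name is an assumption).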
# Import Tinker
import tinker
from tinker import types
# Retry wrapper for API calls
def with_retry(future, max_attempts=3, delay=5):
    """Simple retry logic for API futures"""
    for attempt in range(max_attempts):
        try:
            return future.result()
        except Exception as e:
            if attempt == max_attempts - 1:
                raise
            print(
                f" ⚠ API error (attempt {attempt + 1}/{max_attempts}): {str(e)[:50]}... Retrying in {delay}s"
            )
            time.sleep(delay)
    return None
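# Usage pattern (as in the training loop below): submit the request, then resolve the
# returned future through the wrapper, e.g.
#   result = with_retry(training_client.forward_backward(batch, loss_fn="cross_entropy"))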
print("=" * 60)
print("Tinker FinCoT Fine-Tuning (Chain-of-Thought)")
print("=" * 60)
# ============================================================================
# 1. SETUP & INITIALIZATION
# ============================================================================
print("\n[1/7] Initializing Tinker ServiceClient...")
service_client = tinker.ServiceClient()
print("βœ“ ServiceClient initialized")
# ============================================================================
# 2. DATASET PREPARATION WITH TRAIN/VAL SPLIT
# ============================================================================
print("\n[2/7] Loading FinCoT dataset...")
dataset = load_dataset("TheFinAI/FinCoT")
# Split into train and validation
train_data_raw = dataset["SFT"] # 7,690 examples with reasoning chains
val_data_raw = dataset["RL"].shuffle(seed=42).select(range(500)) # 500 for validation
print(f"βœ“ Loaded {len(train_data_raw)} training examples (SFT split)")
print(f"βœ“ Loaded {len(val_data_raw)} validation examples (from RL split)")
# ============================================================================
# 3. TRAINING CLIENT SETUP
# ============================================================================
print("\n[3/7] Creating LoRA training client for Qwen3-8B...")
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-8B",
    rank=32,  # LoRA rank parameter
)
print("✓ Training client created")
# Get tokenizer
tokenizer = training_client.get_tokenizer()
print("✓ Tokenizer loaded")
# Prepare training data with FinCoT format
def prepare_datum(example, max_length=10000):
    """Convert FinCoT example with reasoning chain to types.Datum

    FinCoT format includes (note: field names are capitalized):
    - Question: The financial question
    - Reasoning_process: Step-by-step chain-of-thought explanation
    - Final_response: The concise final answer

    Args:
        example: FinCoT dataset example
        max_length: Maximum sequence length (default 10000 for Qwen3-8B)

    Returns:
        types.Datum or None if sequence is too long
    """
    # Qwen3 chat format - split into observation and action parts
    # User message
    user_ob = "<|im_start|>user\n"
    user_ac = f"{example['Question']}<|im_end|>"
    # Assistant message (newline prefix since it's the 2nd message)
    assistant_ob = "\n<|im_start|>assistant\n"
    # IMPORTANT: Include both reasoning process AND final answer
    # This teaches the model to show its work before giving the answer
    assistant_ac = f"{example['Reasoning_process']}\n\nFinal Answer: {example['Final_response']}<|im_end|>"
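    # Illustrative shape of one assembled training sequence (invented example, not
    # taken from FinCoT):
    #   <|im_start|>user
    #   What was the firm's net margin in 2023?<|im_end|>
    #   <|im_start|>assistant
    #   Step 1: Net margin = net income / revenue ...
    #
    #   Final Answer: 12.5%<|im_end|>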
    # Tokenize each part separately to track weight boundaries
    user_ob_tokens = tokenizer.encode(user_ob, add_special_tokens=False)
    user_ac_tokens = tokenizer.encode(user_ac, add_special_tokens=False)
    assistant_ob_tokens = tokenizer.encode(assistant_ob, add_special_tokens=False)
    assistant_ac_tokens = tokenizer.encode(assistant_ac, add_special_tokens=False)
    # Combine all tokens
    all_tokens = (
        user_ob_tokens + user_ac_tokens + assistant_ob_tokens + assistant_ac_tokens
    )
    # Check if sequence exceeds max length (Qwen3-8B context limit is 32,768)
    if len(all_tokens) > max_length:
        return None  # Skip this example
    # Weights: only train on assistant's answer (action part)
    weights = np.array(
        [0.0] * len(user_ob_tokens)
        + [0.0] * len(user_ac_tokens)
        + [0.0] * len(assistant_ob_tokens)
        + [1.0] * len(assistant_ac_tokens)
    )
    # CRITICAL: Shift tokens AND weights for next-token prediction
    input_tokens_model = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights_shifted = weights[1:]
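    # Illustration of the shift with a hypothetical 4-token sequence [t0, t1, t2, t3]:
    #   model inputs : [t0, t1, t2]
    #   targets      : [t1, t2, t3]
    #   weights      : [w1, w2, w3]   (weight applied to predicting t1 from t0, etc.)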
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens_model),
        loss_fn_inputs=dict(weights=weights_shifted, target_tokens=target_tokens),
    )
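# Minimal sanity check for prepare_datum (toy example, not from the dataset):
#   datum = prepare_datum({
#       "Question": "What is 2 + 2?",
#       "Reasoning_process": "Add 2 and 2 to get 4.",
#       "Final_response": "4",
#   })
#   assert datum is not None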
def compute_validation_loss(val_data, batch_size=100):
    """Compute loss on validation set (forward only, no backward)"""
    # Sample a batch from validation set
    batch_indices = np.random.choice(
        len(val_data), size=min(batch_size, len(val_data)), replace=False
    )
    batch = [val_data[i] for i in batch_indices]
    # Forward pass only (no backward!)
    fwd_future = training_client.forward(batch, loss_fn="cross_entropy")
    fwd_result = with_retry(fwd_future)
    # Calculate per-token loss
    loss_sum = fwd_result.metrics["loss:sum"]
    total_completion_tokens = sum(
        np.sum(np.array(val_data[i].loss_fn_inputs["weights"].data) > 0)
        for i in batch_indices
    )
    per_token_loss = (
        loss_sum / total_completion_tokens if total_completion_tokens > 0 else 0
    )
    return per_token_loss
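# Note: "loss:sum" is the summed cross-entropy over weighted tokens, so dividing by the
# count of tokens with weight > 0 gives a per-completion-token loss that stays comparable
# across batches of different lengths.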
print("\n[4/7] Processing training and validation data...")
print(" Filtering examples that exceed max sequence length (32,000 tokens)...")
# Process and filter training data
training_data_raw_processed = [prepare_datum(example) for example in train_data_raw]
training_data = [d for d in training_data_raw_processed if d is not None]
train_skipped = len(train_data_raw) - len(training_data)
# Process and filter validation data
val_data_raw_processed = [prepare_datum(example) for example in val_data_raw]
val_data = [d for d in val_data_raw_processed if d is not None]
val_skipped = len(val_data_raw) - len(val_data)
print(
    f"✓ Processed {len(training_data)} train examples (skipped {train_skipped} too-long)"
)
print(
    f"✓ Processed {len(val_data)} validation examples (skipped {val_skipped} too-long)"
)
# ============================================================================
# 5. TRAINING LOOP WITH VALIDATION TRACKING
# ============================================================================
print("\n[5/7] Starting training loop...")
print("-" * 60)
# Training configuration
n_samples = len(training_data)  # ~7,690 before length filtering
n_epochs = 4 # Increased for smaller dataset
batch_size = 32
# Calculate optimal LR for Qwen3-8B with LoRA (from hyperparam_utils formula)
# Formula: base_lr * lora_multiplier * (2000 / hidden_size) ** exponent
# For Qwen: base=5e-5, multiplier=10, hidden_size=4096, exponent=0.0775
learning_rate = 5e-5 * 10.0 * (2000 / 4096) ** 0.0775  # ≈ 4.7e-4
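# Step-by-step arithmetic for the formula above (hidden_size = 4096 assumed for Qwen3-8B):
#   base * multiplier : 5e-5 * 10.0              = 5.0e-4
#   width scaling     : (2000 / 4096) ** 0.0775  ≈ 0.946
#   learning_rate     : 5.0e-4 * 0.946           ≈ 4.7e-4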
print(f"Calculated learning rate: {learning_rate:.6f}")
# Warmup configuration
warmup_steps = 200
print(f"Warmup steps: {warmup_steps}")
# Calculate iterations from epochs
num_iterations = n_epochs * (n_samples // batch_size)  # ≈ 4 * 240 = 960 with the full SFT split
checkpoint_interval = 200
validation_interval = 50
print(f"Total iterations: {num_iterations}")
print(f"Checkpoints every: {checkpoint_interval} iterations")
print(f"Validation every: {validation_interval} iterations")
losses = []
per_token_losses = []
val_losses = [] # Store (iteration, val_loss) tuples
for iteration in range(num_iterations):
    # Sample random batch
    batch_indices = np.random.choice(len(training_data), size=batch_size, replace=False)
    batch = [training_data[i] for i in batch_indices]
    # Apply learning rate warmup
    if iteration < warmup_steps:
        current_lr = learning_rate * (iteration + 1) / warmup_steps
    else:
        current_lr = learning_rate
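    # Linear warmup, illustrated (learning_rate ≈ 4.7e-4, warmup_steps = 200):
    #   iteration 0   -> lr ≈ 2.4e-6
    #   iteration 99  -> lr ≈ 2.4e-4
    #   iteration 199 -> lr ≈ 4.7e-4 (full rate from here on)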
    # API Primitive 1: forward_backward - compute gradients
    fwdbwd_future = training_client.forward_backward(batch, loss_fn="cross_entropy")
    # API Primitive 2: optim_step - update parameters
    optim_future = training_client.optim_step(
        types.AdamParams(learning_rate=current_lr)
    )
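    # Note: both requests are submitted before either result is awaited, so the service
    # can process them back-to-back instead of waiting on a client round-trip in between.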
    # Wait for results with retry logic
    fwdbwd_result = with_retry(fwdbwd_future)
    optim_result = with_retry(optim_future)
    # Track loss (from metrics)
    loss_sum = fwdbwd_result.metrics["loss:sum"]
    # Calculate per-token loss (count tokens with weight > 0)
    total_completion_tokens = sum(
        np.sum(np.array(training_data[i].loss_fn_inputs["weights"].data) > 0)
        for i in batch_indices
    )
    per_token_loss = (
        loss_sum / total_completion_tokens if total_completion_tokens > 0 else 0
    )
    losses.append(loss_sum)
    per_token_losses.append(per_token_loss)
    # Print progress every 10 iterations
    if iteration % 10 == 0:
        warmup_indicator = "🔥" if iteration < warmup_steps else " "
        print(
            f"{warmup_indicator} Iteration {iteration:4d} | Train Loss: {per_token_loss:.4f} | LR: {current_lr:.6f}"
        )
    # Compute validation loss periodically
    if iteration % validation_interval == 0:
        val_loss = compute_validation_loss(val_data)
        val_losses.append((iteration, val_loss))
        gap = val_loss - per_token_loss
        gap_indicator = "⚠️" if gap > 0.2 else "✓"
        print(
            f" 📊 {gap_indicator} Iteration {iteration:4d} | Train: {per_token_loss:.4f} | Val: {val_loss:.4f} | Gap: {gap:+.4f}"
        )
    # Save checkpoints every N iterations
    if iteration > 0 and iteration % checkpoint_interval == 0:
        print(f" 💾 Saving checkpoint at iteration {iteration}...")
        # Save full state (for resuming training)
        state_result = training_client.save_state(name=f"fincot-full-state-{iteration}")
        state_path = with_retry(state_result).path
        print(f" ✓ Full state: {state_path}")
        # Also save sampler weights (for inference)
        sampler_result = training_client.save_weights_for_sampler(
            name=f"fincot-checkpoint-{iteration}"
        )
        sampler_path = with_retry(sampler_result).path
        print(f" ✓ Sampler weights: {sampler_path}")
print("-" * 60)
print(f"βœ“ Training complete!")
# Calculate final metrics
final_train_loss = per_token_losses[-1]
final_val_loss = compute_validation_loss(
val_data, batch_size=200
) # Larger batch for final eval
print(f"\nFinal Results:")
print(f" Training Loss: {final_train_loss:.4f}")
print(f" Validation Loss: {final_val_loss:.4f}")
print(f" Train/Val Gap: {final_val_loss - final_train_loss:+.4f}")
if final_val_loss - final_train_loss > 0.2:
print(" ⚠️ Warning: Large train/val gap suggests overfitting")
elif abs(final_val_loss - final_train_loss) < 0.1:
print(" βœ… Good: Small train/val gap suggests healthy generalization")
else:
print(" βœ“ Moderate: Train/val gap is acceptable")
# ============================================================================
# 6. SAVE FINAL MODEL
# ============================================================================
print("\n[6/7] Saving final model...")
# Save final checkpoint (both types)
# API Primitive 3: save_state - persist full training state (for resuming later)
final_state_result = training_client.save_state(name="fincot-final")
final_state_path = with_retry(final_state_result).path
print(f"✓ Final full state saved: {final_state_path}")
# Save sampler weights and get a sampling client for inference
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="fincot-qwen3-8b-lora-final"
)
print("✓ Final sampler weights saved")
# ============================================================================
# 7. TEST WITH SAMPLE QUESTIONS
# ============================================================================
print("\n[7/7] Testing model with sample questions...")
# Configure sampling parameters (increased max_tokens for reasoning)
sampling_params = types.SamplingParams(
    max_tokens=400,  # Increased to allow for reasoning chains
    temperature=0.7,
    top_p=0.9,
    stop=["<|im_end|>"],  # Qwen3 stop token
)
# Test questions
test_questions = [
    "What are the main risks associated with investing in stocks?",
    "How does diversification help reduce portfolio risk?",
    "What is the difference between a stock and a bond?",
]
print("\n" + "=" * 60)
print("SAMPLE OUTPUTS (with reasoning chains)")
print("=" * 60)
for i, question in enumerate(test_questions, 1):
    # Format with Qwen3 chat template
    prompt_text = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    model_input = types.ModelInput.from_ints(tokenizer.encode(prompt_text))
    # API Primitive 4: sample - generate predictions (with retry)
    response = with_retry(
        sampling_client.sample(
            prompt=model_input, num_samples=1, sampling_params=sampling_params
        )
    )
    # Decode response (sequences[0].tokens)
    answer = tokenizer.decode(response.sequences[0].tokens)
    print(f"\nQ{i}: {question}")
    print(f"A{i}: {answer[:300]}{'...' if len(answer) > 300 else ''}")
    print("-" * 60)
print("\nβœ… All 4 API primitives demonstrated successfully!")
print(" 1. forward_backward βœ“")
print(" 2. optim_step βœ“")
print(" 3. save_state βœ“")
print(" 4. sample βœ“")
print("\nπŸ“Š Training Summary:")
print(f" Dataset: FinCoT (7,690 train + 500 val)")
print(f" Epochs: {n_epochs}")
print(f" Total iterations: {num_iterations}")
print(f" Initial loss: {per_token_losses[0]:.4f}")
print(f" Final train loss: {final_train_loss:.4f}")
print(f" Final val loss: {final_val_loss:.4f}")
print(f" Improvement: {per_token_losses[0] - final_train_loss:.4f}")
print("\n" + "=" * 60)