| """ | |
| Tinker Financial Q&A Fine-Tuning with FinCoT Dataset | |
| Uses chain-of-thought reasoning dataset for improved answer quality | |
| Includes validation tracking, warmup, and proper checkpoint management | |
| """ | |
| import time | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| from datasets import load_dataset | |
| # Load environment variables | |
| load_dotenv() | |
| # Import Tinker | |
| import tinker | |
| from tinker import types | |
| # Retry wrapper for API calls | |
| def with_retry(future, max_attempts=3, delay=5): | |
| """Simple retry logic for API futures""" | |
| for attempt in range(max_attempts): | |
| try: | |
| return future.result() | |
| except Exception as e: | |
| if attempt == max_attempts - 1: | |
| raise | |
| print( | |
| f" β API error (attempt {attempt + 1}/{max_attempts}): {str(e)[:50]}... Retrying in {delay}s" | |
| ) | |
| time.sleep(delay) | |
| return None | |
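# Illustrative usage (sketch, not part of the original flow):
#   result = with_retry(training_client.forward_backward(batch, loss_fn="cross_entropy"))
# Note that this retries .result() on the SAME future, so it can recover from
# transient polling errors but does not re-submit the underlying request.
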
| print("=" * 60) | |
| print("Tinker FinCoT Fine-Tuning (Chain-of-Thought)") | |
| print("=" * 60) | |
| # ============================================================================ | |
| # 1. SETUP & INITIALIZATION | |
| # ============================================================================ | |
| print("\n[1/7] Initializing Tinker ServiceClient...") | |
| service_client = tinker.ServiceClient() | |
| print("β ServiceClient initialized") | |
| # ============================================================================ | |
| # 2. DATASET PREPARATION WITH TRAIN/VAL SPLIT | |
| # ============================================================================ | |
| print("\n[2/7] Loading FinCoT dataset...") | |
| dataset = load_dataset("TheFinAI/FinCoT") | |
| # Split into train and validation | |
| train_data_raw = dataset["SFT"] # 7,690 examples with reasoning chains | |
| val_data_raw = dataset["RL"].shuffle(seed=42).select(range(500)) # 500 for validation | |
| print(f"β Loaded {len(train_data_raw)} training examples (SFT split)") | |
| print(f"β Loaded {len(val_data_raw)} validation examples (from RL split)") | |
# ============================================================================
# 3. TRAINING CLIENT SETUP
# ============================================================================
print("\n[3/7] Creating LoRA training client for Qwen3-8B...")
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-8B",
    rank=32,  # LoRA rank parameter
)
print("✓ Training client created")

# Get tokenizer
tokenizer = training_client.get_tokenizer()
print("✓ Tokenizer loaded")


# Prepare training data with FinCoT format
def prepare_datum(example, max_length=10000):
    """Convert FinCoT example with reasoning chain to types.Datum

    FinCoT format includes (note: field names are capitalized):
    - Question: The financial question
    - Reasoning_process: Step-by-step chain-of-thought explanation
    - Final_response: The concise final answer

    Args:
        example: FinCoT dataset example
        max_length: Maximum sequence length (default 10,000 tokens, well under
            the 32,768-token Qwen3-8B context window)

    Returns:
        types.Datum or None if sequence is too long
    """
    # Qwen3 chat format - split into observation and action parts
    # User message
    user_ob = "<|im_start|>user\n"
    user_ac = f"{example['Question']}<|im_end|>"

    # Assistant message (newline prefix since it's the 2nd message)
    assistant_ob = "\n<|im_start|>assistant\n"
    # IMPORTANT: Include both reasoning process AND final answer
    # This teaches the model to show its work before giving the answer
    assistant_ac = f"{example['Reasoning_process']}\n\nFinal Answer: {example['Final_response']}<|im_end|>"
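
    # For illustration, the fully assembled sequence (content abridged) renders as:
    #   <|im_start|>user
    #   What is the P/E ratio?<|im_end|>
    #   <|im_start|>assistant
    #   Step 1: ... reasoning ...
    #
    #   Final Answer: ...<|im_end|>
    # Only the assistant "action" span receives a nonzero loss weight below.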
    # Tokenize each part separately to track weight boundaries
    user_ob_tokens = tokenizer.encode(user_ob, add_special_tokens=False)
    user_ac_tokens = tokenizer.encode(user_ac, add_special_tokens=False)
    assistant_ob_tokens = tokenizer.encode(assistant_ob, add_special_tokens=False)
    assistant_ac_tokens = tokenizer.encode(assistant_ac, add_special_tokens=False)

    # Combine all tokens
    all_tokens = (
        user_ob_tokens + user_ac_tokens + assistant_ob_tokens + assistant_ac_tokens
    )

    # Check if sequence exceeds max length (Qwen3-8B limit is 32,768)
    if len(all_tokens) > max_length:
        return None  # Skip this example

    # Weights: only train on assistant's answer (action part)
    weights = np.array(
        [0.0] * len(user_ob_tokens)
        + [0.0] * len(user_ac_tokens)
        + [0.0] * len(assistant_ob_tokens)
        + [1.0] * len(assistant_ac_tokens)
    )

    # CRITICAL: Shift tokens AND weights for next-token prediction
    input_tokens_model = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights_shifted = weights[1:]
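
    # Illustrative alignment for a hypothetical 4-token sequence [A, B, C, D]:
    #   inputs:  [A, B, C]        (all_tokens[:-1])
    #   targets: [B, C, D]        (all_tokens[1:])
    #   weights: [w_B, w_C, w_D]  (weights[1:])
    # The model reading token A is scored on predicting B, and that prediction
    # is weighted by B's weight, which is why weights shift along with targets.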
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens_model),
        loss_fn_inputs=dict(weights=weights_shifted, target_tokens=target_tokens),
    )


def compute_validation_loss(val_data, batch_size=100):
    """Compute loss on validation set (forward only, no backward)"""
    # Sample a batch from validation set
    batch_indices = np.random.choice(
        len(val_data), size=min(batch_size, len(val_data)), replace=False
    )
    batch = [val_data[i] for i in batch_indices]

    # Forward pass only (no backward!)
    fwd_future = training_client.forward(batch, loss_fn="cross_entropy")
    fwd_result = with_retry(fwd_future)

    # Calculate per-token loss
    loss_sum = fwd_result.metrics["loss:sum"]
    total_completion_tokens = sum(
        np.sum(np.array(val_data[i].loss_fn_inputs["weights"].data) > 0)
        for i in batch_indices
    )
    per_token_loss = (
        loss_sum / total_completion_tokens if total_completion_tokens > 0 else 0
    )
    return per_token_loss

| print("\n[4/7] Processing training and validation data...") | |
| print(" Filtering examples that exceed max sequence length (32,000 tokens)...") | |
| # Process and filter training data | |
| training_data_raw_processed = [prepare_datum(example) for example in train_data_raw] | |
| training_data = [d for d in training_data_raw_processed if d is not None] | |
| train_skipped = len(train_data_raw) - len(training_data) | |
| # Process and filter validation data | |
| val_data_raw_processed = [prepare_datum(example) for example in val_data_raw] | |
| val_data = [d for d in val_data_raw_processed if d is not None] | |
| val_skipped = len(val_data_raw) - len(val_data) | |
| print( | |
| f"β Processed {len(training_data)} train examples (skipped {train_skipped} too-long)" | |
| ) | |
| print( | |
| f"β Processed {len(val_data)} validation examples (skipped {val_skipped} too-long)" | |
| ) | |
# ============================================================================
# 5. TRAINING LOOP WITH VALIDATION TRACKING
# ============================================================================
print("\n[5/7] Starting training loop...")
print("-" * 60)

# Training configuration
n_samples = len(training_data)  # ~7,690, minus any skipped too-long examples
n_epochs = 4  # Increased for smaller dataset
batch_size = 32

# Calculate optimal LR for Qwen3-8B with LoRA (from hyperparam_utils formula)
# Formula: base_lr * lora_multiplier * (2000 / hidden_size) ** exponent
# For Qwen: base=5e-5, multiplier=10, hidden_size=4096, exponent=0.0775
learning_rate = 5e-5 * 10.0 * (2000 / 4096) ** 0.0775  # ≈ 4.7e-4
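# Worked through (values approximate): (2000 / 4096) ** 0.0775 ≈ 0.488 ** 0.0775
# ≈ 0.946, so learning_rate ≈ 5e-5 * 10.0 * 0.946 ≈ 4.7e-4.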
| print(f"Calculated learning rate: {learning_rate:.6f}") | |
| # Warmup configuration | |
| warmup_steps = 200 | |
| print(f"Warmup steps: {warmup_steps}") | |
| # Calculate iterations from epochs | |
| num_iterations = n_epochs * (n_samples // batch_size) # 4 * 240 = ~961 | |
| checkpoint_interval = 200 | |
| validation_interval = 50 | |
| print(f"Total iterations: {num_iterations}") | |
| print(f"Checkpoints every: {checkpoint_interval} iterations") | |
| print(f"Validation every: {validation_interval} iterations") | |
losses = []
per_token_losses = []
val_losses = []  # Store (iteration, val_loss) tuples

for iteration in range(num_iterations):
    # Sample random batch
    batch_indices = np.random.choice(len(training_data), size=batch_size, replace=False)
    batch = [training_data[i] for i in batch_indices]

    # Apply learning rate warmup
    if iteration < warmup_steps:
        current_lr = learning_rate * (iteration + 1) / warmup_steps
    else:
        current_lr = learning_rate
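
    # Linear warmup example (values approximate): with warmup_steps=200 and
    # learning_rate ≈ 4.7e-4, iteration 0 uses ≈ 2.4e-6 (1/200 of the full LR),
    # iteration 99 uses ≈ 2.4e-4 (half), and iteration 199 reaches the full LR.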
    # API Primitive 1: forward_backward - compute gradients
    fwdbwd_future = training_client.forward_backward(batch, loss_fn="cross_entropy")

    # API Primitive 2: optim_step - update parameters
    optim_future = training_client.optim_step(
        types.AdamParams(learning_rate=current_lr)
    )

    # Wait for results with retry logic
    fwdbwd_result = with_retry(fwdbwd_future)
    optim_result = with_retry(optim_future)

    # Track loss (from metrics)
    loss_sum = fwdbwd_result.metrics["loss:sum"]

    # Calculate per-token loss (count tokens with weight > 0)
    total_completion_tokens = sum(
        np.sum(np.array(training_data[i].loss_fn_inputs["weights"].data) > 0)
        for i in batch_indices
    )
    per_token_loss = (
        loss_sum / total_completion_tokens if total_completion_tokens > 0 else 0
    )
    losses.append(loss_sum)
    per_token_losses.append(per_token_loss)

    # Print progress every 10 iterations
    if iteration % 10 == 0:
        warmup_indicator = "🔥" if iteration < warmup_steps else " "
        print(
            f"{warmup_indicator} Iteration {iteration:4d} | Train Loss: {per_token_loss:.4f} | LR: {current_lr:.6f}"
        )

    # Compute validation loss periodically
    if iteration % validation_interval == 0:
        val_loss = compute_validation_loss(val_data)
        val_losses.append((iteration, val_loss))
        gap = val_loss - per_token_loss
        gap_indicator = "⚠️" if gap > 0.2 else "✓"
        print(
            f"   📊 {gap_indicator} Iteration {iteration:4d} | Train: {per_token_loss:.4f} | Val: {val_loss:.4f} | Gap: {gap:+.4f}"
        )

    # Save checkpoints every N iterations
    if iteration > 0 and iteration % checkpoint_interval == 0:
        print(f"   💾 Saving checkpoint at iteration {iteration}...")

        # Save full state (for resuming training)
        state_result = training_client.save_state(name=f"fincot-full-state-{iteration}")
        state_path = with_retry(state_result).path
        print(f"   ✓ Full state: {state_path}")

        # Also save sampler weights (for inference)
        sampler_result = training_client.save_weights_for_sampler(
            name=f"fincot-checkpoint-{iteration}"
        )
        sampler_path = with_retry(sampler_result).path
        print(f"   ✓ Sampler weights: {sampler_path}")
| print(f"β Training complete!") | |
| # Calculate final metrics | |
| final_train_loss = per_token_losses[-1] | |
| final_val_loss = compute_validation_loss( | |
| val_data, batch_size=200 | |
| ) # Larger batch for final eval | |
| print(f"\nFinal Results:") | |
| print(f" Training Loss: {final_train_loss:.4f}") | |
| print(f" Validation Loss: {final_val_loss:.4f}") | |
| print(f" Train/Val Gap: {final_val_loss - final_train_loss:+.4f}") | |
| if final_val_loss - final_train_loss > 0.2: | |
| print(" β οΈ Warning: Large train/val gap suggests overfitting") | |
| elif abs(final_val_loss - final_train_loss) < 0.1: | |
| print(" β Good: Small train/val gap suggests healthy generalization") | |
| else: | |
| print(" β Moderate: Train/val gap is acceptable") | |
# ============================================================================
# 6. SAVE FINAL MODEL
# ============================================================================
print("\n[6/7] Saving final model...")

# Save final checkpoint (both types)
# API Primitive 3: save_state - persist full training state
final_state_result = training_client.save_state(name="fincot-final")
final_state_path = with_retry(final_state_result).path
print(f"✓ Final full state saved: {final_state_path}")

# Save sampler weights and get a sampling client for inference
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="fincot-qwen3-8b-lora-final"
)
print("✓ Final sampler weights saved")
# ============================================================================
# 7. TEST WITH SAMPLE QUESTIONS
# ============================================================================
print("\n[7/7] Testing model with sample questions...")

# Configure sampling parameters (increased max_tokens for reasoning)
sampling_params = types.SamplingParams(
    max_tokens=400,  # Increased to allow for reasoning chains
    temperature=0.7,
    top_p=0.9,
    stop=["<|im_end|>"],  # Qwen3 stop token
)

# Test questions
test_questions = [
    "What are the main risks associated with investing in stocks?",
    "How does diversification help reduce portfolio risk?",
    "What is the difference between a stock and a bond?",
]

print("\n" + "=" * 60)
print("SAMPLE OUTPUTS (with reasoning chains)")
print("=" * 60)

for i, question in enumerate(test_questions, 1):
    # Format with Qwen3 chat template
    prompt_text = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    model_input = types.ModelInput.from_ints(tokenizer.encode(prompt_text))

    # API Primitive 4: sample - generate predictions (with retry)
    response = with_retry(
        sampling_client.sample(
            prompt=model_input, num_samples=1, sampling_params=sampling_params
        )
    )

    # Decode response (sequences[0].tokens)
    answer = tokenizer.decode(response.sequences[0].tokens)
    print(f"\nQ{i}: {question}")
    print(f"A{i}: {answer[:300]}{'...' if len(answer) > 300 else ''}")
    print("-" * 60)
| print("\nβ All 4 API primitives demonstrated successfully!") | |
| print(" 1. forward_backward β") | |
| print(" 2. optim_step β") | |
| print(" 3. save_state β") | |
| print(" 4. sample β") | |
| print("\nπ Training Summary:") | |
| print(f" Dataset: FinCoT (7,690 train + 500 val)") | |
| print(f" Epochs: {n_epochs}") | |
| print(f" Total iterations: {num_iterations}") | |
| print(f" Initial loss: {per_token_losses[0]:.4f}") | |
| print(f" Final train loss: {final_train_loss:.4f}") | |
| print(f" Final val loss: {final_val_loss:.4f}") | |
| print(f" Improvement: {per_token_losses[0] - final_train_loss:.4f}") | |
| print("\n" + "=" * 60) |