"""
Smart Robot Episode Dataloader - solves the variable episode length problem
that causes training bottlenecks in robot learning datasets.

Key innovation: content-aware sharding that balances total compute workload,
not just episode count, across distributed workers.
"""

import random
import time
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import numpy as np


@dataclass
class EpisodeMetadata:
    """Metadata for a single robot episode."""

    episode_id: str
    num_frames: int
    action_dim: int
    task_description: str
    camera_views: List[str]
    estimated_compute_cost: float  # frames * action_dim * num_cameras

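# Worked example of the cost proxy above (illustrative numbers): a 200-frame
# "cooking" episode with 7-D actions recorded from two camera views gets
# estimated_compute_cost = 200 * 7 * 2 = 2800.
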
class RobotDatasetAnalyzer:
    """Analyzes robot datasets to understand episode length distributions."""

    def __init__(self):
        self.episode_stats = []

    def analyze_episode(self, episode_path: Path) -> EpisodeMetadata:
        """Analyze a single episode to extract metadata."""
        # Simulate reading episode metadata (a real implementation would parse
        # the actual robot data files).
        episode_id = episode_path.stem

        # Simulate realistic robot episode characteristics.
        # Based on the SmolVLA paper: episodes vary wildly in length.
        if "pick_place" in episode_id:
            num_frames = random.randint(50, 150)   # quick tasks
            action_dim = 7                         # 6-DOF arm + gripper
        elif "cooking" in episode_id:
            num_frames = random.randint(200, 800)  # long tasks
            action_dim = 7
        elif "navigation" in episode_id:
            num_frames = random.randint(100, 400)  # medium tasks
            action_dim = 3                         # x, y, theta
        else:
            num_frames = random.randint(30, 600)   # variable
            action_dim = random.choice([3, 6, 7])

        # Camera setups vary by dataset (as shown in the SmolVLA filtering tool).
        camera_views = random.choice([
            ["wrist", "overhead"],
            ["wrist"],
            ["overhead", "side", "wrist"],
            ["ego"],
        ])

        # Task descriptions (SmolVLA standardized these).
        tasks = [
            "Pick up the cube and place it in the box",
            "Open the drawer and put item inside",
            "Navigate to the red marker",
            "Stack the blocks in order",
            "Pour water into the cup",
        ]
        task_description = random.choice(tasks)

        # Compute cost = frames * action_dim * cameras (proxy for training time).
        compute_cost = num_frames * action_dim * len(camera_views)

        return EpisodeMetadata(
            episode_id=episode_id,
            num_frames=num_frames,
            action_dim=action_dim,
            task_description=task_description,
            camera_views=camera_views,
            estimated_compute_cost=compute_cost,
        )

    def analyze_dataset(self, dataset_path: Path, max_episodes: Optional[int] = None) -> List[EpisodeMetadata]:
        """Analyze an entire dataset to build per-episode metadata."""
        print(f"Analyzing robot dataset at {dataset_path}")

        # Simulate episode files (a real implementation would glob for actual files).
        episode_files = []
        for i in range(100):  # simulate 100 episodes
            task_type = random.choice(["pick_place", "cooking", "navigation", "manipulation"])
            episode_files.append(Path(f"episode_{task_type}_{i:03d}"))

        if max_episodes:
            episode_files = episode_files[:max_episodes]

        # Analyze episodes in parallel.
        with ProcessPoolExecutor(max_workers=4) as executor:
            episodes = list(executor.map(self.analyze_episode, episode_files))

        # Print summary statistics.
        frame_counts = [ep.num_frames for ep in episodes]
        compute_costs = [ep.estimated_compute_cost for ep in episodes]

        print("Dataset Statistics:")
        print(f"  Episodes: {len(episodes)}")
        print(f"  Frame count: min={min(frame_counts)}, max={max(frame_counts)}, avg={np.mean(frame_counts):.1f}")
        print(f"  Compute cost: min={min(compute_costs):.0f}, max={max(compute_costs):.0f}, avg={np.mean(compute_costs):.0f}")
        print(f"  Length variation: {max(frame_counts) / min(frame_counts):.1f}x")

        return episodes


class SmartEpisodeBatcher:
    """Smart batching that balances compute load across workers."""

    def __init__(self, episodes: List[EpisodeMetadata]):
        self.episodes = episodes
        self.total_compute_cost = sum(ep.estimated_compute_cost for ep in episodes)

    def create_balanced_shards(self, num_workers: int) -> List[List[EpisodeMetadata]]:
        """Create balanced shards using a greedy bin-packing algorithm."""
        print(f"Creating {num_workers} balanced shards...")

        # Sort episodes by compute cost (largest first).
        sorted_episodes = sorted(self.episodes, key=lambda x: x.estimated_compute_cost, reverse=True)

        # Initialize one bin per worker.
        worker_shards = [[] for _ in range(num_workers)]
        worker_costs = [0.0 for _ in range(num_workers)]

        # Greedy bin packing: assign each episode to the least-loaded worker.
        for episode in sorted_episodes:
            min_worker = min(range(num_workers), key=lambda i: worker_costs[i])
            worker_shards[min_worker].append(episode)
            worker_costs[min_worker] += episode.estimated_compute_cost

        # Report load-balancing results.
        target_cost = self.total_compute_cost / num_workers
        print("Load Balancing Results:")
        for i, (shard, cost) in enumerate(zip(worker_shards, worker_costs)):
            episodes_count = len(shard)
            load_ratio = cost / target_cost
            print(f"  Worker {i}: {episodes_count:2d} episodes, cost={cost:.0f} ({load_ratio:.2f}x target)")

        # Load imbalance = cost of the most loaded worker over the least loaded one.
        max_cost = max(worker_costs)
        min_cost = min(worker_costs)
        imbalance = max_cost / min_cost if min_cost > 0 else float("inf")
        print(f"  Load imbalance: {imbalance:.2f}x (lower is better)")

        return worker_shards

    def naive_sharding_comparison(self, num_workers: int):
        """Show how bad naive episode-count sharding would be."""
        print("\nNaive Sharding Comparison:")

        episodes_per_worker = len(self.episodes) // num_workers
        naive_costs = []

        for i in range(num_workers):
            start_idx = i * episodes_per_worker
            end_idx = start_idx + episodes_per_worker
            if i == num_workers - 1:  # last worker gets the remaining episodes
                end_idx = len(self.episodes)

            worker_episodes = self.episodes[start_idx:end_idx]
            worker_cost = sum(ep.estimated_compute_cost for ep in worker_episodes)
            naive_costs.append(worker_cost)

            print(f"  Naive Worker {i}: {len(worker_episodes):2d} episodes, cost={worker_cost:.0f}")

        naive_imbalance = max(naive_costs) / min(naive_costs) if min(naive_costs) > 0 else float("inf")
        print(f"  Naive imbalance: {naive_imbalance:.2f}x")

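# Worked example of the greedy balancing above (illustrative costs, 2 workers):
#   costs = [800, 500, 400, 300, 200, 100]
#   Greedy (largest first, always to the least-loaded worker):
#     worker 0 -> 800, 300, 100   total 1200
#     worker 1 -> 500, 400, 200   total 1100   (imbalance ~1.09x)
#   Naive contiguous split of the same list:
#     worker 0 -> 800, 500, 400   total 1700
#     worker 1 -> 300, 200, 100   total 600    (imbalance ~2.83x)
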
def worker_training_time(shard: List[EpisodeMetadata]) -> float:
    """Simulate training time for one worker's shard."""
    total_time = 0.0
    for episode in shard:
        # Processing time is proportional to the episode's compute cost.
        processing_time = episode.estimated_compute_cost / 1000  # scale to seconds
        total_time += processing_time
        time.sleep(0.001)  # tiny sleep to simulate actual work
    return total_time


def simulate_training_performance(shards: List[List[EpisodeMetadata]], method_name: str) -> float:
    """Simulate training performance with a given sharding strategy."""
    print(f"\nSimulating Training Performance - {method_name}")

    # Time each worker's shard in parallel (simulating distributed execution).
    start_time = time.time()
    with ProcessPoolExecutor(max_workers=len(shards)) as executor:
        worker_times = list(executor.map(worker_training_time, shards))

    # In synchronous distributed training, the slowest worker is the bottleneck.
    total_training_time = max(worker_times)
    actual_time = time.time() - start_time  # wall-clock time of this simulation run

    print(f"  Training time per worker: {[f'{t:.2f}s' for t in worker_times]}")
    print(f"  Bottleneck (slowest worker): {total_training_time:.2f}s")
    print(f"  Efficiency: {min(worker_times) / max(worker_times) * 100:.1f}%")

    return total_training_time


def main():
    """Demo the smart robot episode dataloader."""
    print("Smart Robot Episode Dataloader Demo")
    print("=" * 50)

    # Step 1: Analyze a robot dataset.
    analyzer = RobotDatasetAnalyzer()
    dataset_path = Path("./mock_robot_dataset")  # would be a real path in practice
    episodes = analyzer.analyze_dataset(dataset_path, max_episodes=50)

    # Step 2: Create smart, compute-balanced shards.
    batcher = SmartEpisodeBatcher(episodes)
    num_workers = 4
    balanced_shards = batcher.create_balanced_shards(num_workers)

    # Step 3: Compare with naive sharding.
    batcher.naive_sharding_comparison(num_workers)

    # Step 4: Simulate training performance.
    smart_time = simulate_training_performance(balanced_shards, "Smart Balancing")

    # Create naive shards for comparison.
    episodes_per_worker = len(episodes) // num_workers
    naive_shards = []
    for i in range(num_workers):
        start_idx = i * episodes_per_worker
        end_idx = start_idx + episodes_per_worker
        if i == num_workers - 1:
            end_idx = len(episodes)
        naive_shards.append(episodes[start_idx:end_idx])

    naive_time = simulate_training_performance(naive_shards, "Naive Sharding")

    # Step 5: Show the improvement.
    speedup = naive_time / smart_time
    print("\nResults Summary:")
    print(f"  Smart balancing training time: {smart_time:.2f}s")
    print(f"  Naive sharding training time: {naive_time:.2f}s")
    print(f"  Speedup: {speedup:.2f}x faster")
    print(f"  Time saved: {naive_time - smart_time:.2f}s ({(speedup - 1) * 100:.1f}% improvement)")


if __name__ == "__main__":
    main()