Skip to content

Instantly share code, notes, and snippets.

@AcTePuKc
Last active February 21, 2025 18:08
Show Gist options
  • Select an option

  • Save AcTePuKc/4cfe0426aae1c50048382302767f9f8d to your computer and use it in GitHub Desktop.

Select an option

Save AcTePuKc/4cfe0426aae1c50048382302767f9f8d to your computer and use it in GitHub Desktop.
Kokoro-ONNX Benchmark CPU/GPU
import csv
import html
import multiprocessing
import os
import platform
import subprocess
import sys
import time
import webbrowser
from datetime import datetime
from pathlib import Path

import torch
from PIL import Image, ImageDraw, ImageFont
from tabulate import tabulate

from kokoro_onnx import Kokoro  # Assuming kokoro_onnx is installed or in the same directory.
def get_system_info():
    """Return a best-effort ``(cpu_name, gpu_name)`` pair describing this machine.

    CPU name lookup:
        1. Ask the OS: ``wmic`` on Windows, ``sysctl`` on macOS, ``lscpu``
           elsewhere.
        2. Fall back to ``platform.processor()``.
        3. Fall back to ``"Unknown CPU"``.

    GPU name: the name of CUDA device 0 as reported by torch, or ``"No GPU"``
    when CUDA is unavailable.

    Never raises. Unexpected errors are printed, and whatever has been
    determined so far is returned.
    """
    cpu_name = "Unknown CPU"
    gpu_name = "No GPU"
    try:
        system = platform.system()
        try:
            if system == "Windows":
                output = subprocess.check_output(
                    ["wmic", "cpu", "get", "name"],
                    stderr=subprocess.DEVNULL, text=True, timeout=10,
                )
                # wmic terminates lines with "\r\r\n", which text mode turns into
                # extra blank lines. Drop blanks; line 0 is the "Name" header.
                lines = [line.strip() for line in output.splitlines() if line.strip()]
                if len(lines) > 1:
                    cpu_name = lines[1]
            elif system == "Darwin":
                output = subprocess.check_output(
                    ["sysctl", "-n", "machdep.cpu.brand_string"],
                    stderr=subprocess.DEVNULL, text=True, timeout=10,
                )
                cpu_name = output.strip() or cpu_name
            else:
                output = subprocess.check_output(
                    ["lscpu"], stderr=subprocess.DEVNULL, text=True, timeout=10,
                )
                # Match the key exactly, so e.g. "BIOS Model name" is ignored.
                # Split only on the first ':' so names containing ':' survive.
                for line in output.splitlines():
                    key, sep, value = line.partition(":")
                    if sep and key.strip() == "Model name" and value.strip():
                        cpu_name = value.strip()
                        break
        except (subprocess.SubprocessError, OSError):
            pass  # Tool missing, failed or timed out: use the fallback below.
        if cpu_name == "Unknown CPU":
            cpu_name = platform.processor() or "Unknown CPU"
        if torch.cuda.is_available():
            try:
                gpu_name = torch.cuda.get_device_name(0)
            except Exception:
                gpu_name = "GPU Detected (Error getting name)"
    except Exception as e:
        print(f"Error getting system info: {e}")
    return cpu_name, gpu_name
def initialize_paths():
    """Resolve the input and output paths for one benchmark run.

    Outputs go to a ``benchmark`` folder beside this script, which is created
    if missing. The CSV, HTML report and badge file names all embed the same
    run timestamp.

    Returns:
        A 7-tuple ``(model_folder, benchmark_folder, timestamp,
        csv_results_path, html_results_path, badge_path, voice_file)``.

    The process exits with status 1 when ``voices-v1.0.bin`` is absent from
    the script folder, because synthesis is impossible without it.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, "benchmark")
    os.makedirs(out_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    csv_path, html_path, badge_file = (
        os.path.join(out_dir, name)
        for name in (
            f"benchmark_results_{stamp}.csv",
            f"benchmark_results_{stamp}.html",
            f"kokoro_badge_{stamp}.png",
        )
    )
    voices_path = os.path.join(script_dir, "voices-v1.0.bin")

    if os.path.exists(voices_path):
        return script_dir, out_dir, stamp, csv_path, html_path, badge_file, voices_path
    print(f"❌ ERROR: voice file not found: {voices_path}. Please ensure it exists.")
    sys.exit(1)
def find_onnx_models(model_folder):
    """Return the ``.onnx`` file names found in *model_folder*, sorted alphabetically.

    Exits the process with status 1 when there are none, because there is
    nothing to benchmark.
    """
    models = sorted(name for name in os.listdir(model_folder) if name.endswith(".onnx"))
    if models:
        return models
    print("❌ ERROR: No .onnx models found in the directory.")
    sys.exit(1)
def load_voices(model_folder, onnx_models, voice_file):
    """Return the list of voice names exposed by ``Kokoro(...).voices``.

    Constructing a Kokoro instance requires a model. The models in
    *onnx_models* (file names relative to *model_folder*) are therefore tried
    in order, and the first one that loads is used. A single broken model no
    longer aborts the run, matching how benchmark_voice skips unloadable
    models.

    NOTE(review): assumes every model exposes the same voice list — confirm
    against kokoro_onnx.

    Exits the process with status 1 if no model could be loaded, including
    when *onnx_models* is empty.
    """
    last_error = "no .onnx models were provided"
    for model_name in onnx_models:
        try:
            kokoro_test = Kokoro(os.path.join(model_folder, model_name), voice_file)
            voices = list(kokoro_test.voices.keys())
        except Exception as e:
            print(f"⚠️ Could not load voices using {model_name}: {e}")
            last_error = e
            continue
        return voices
    print(f"❌ ERROR: Could not load voices from the model: {last_error}")
    sys.exit(1)
def benchmark_voice(
    args,
    sentences=("Hello, world!", "This is a test sentence.", "Kokoro TTS is amazing!"),
    iterations=2,
    lang="en-us",
    speed=1.0,
):
    """Time Kokoro synthesis of one voice on every given model.

    Shaped for ``multiprocessing.Pool.map``: the required inputs arrive as a
    single tuple. The keyword parameters keep the original hard-coded values
    as defaults.

    Args:
        args: ``(voice, model_files, model_folder, voice_file)``. These are
            the voice name, the ``.onnx`` file names relative to
            *model_folder*, and the voices file path.
        sentences: texts synthesized for every model.
        iterations: number of timed runs per sentence and model.
        lang: language code passed to ``Kokoro.create``.
        speed: speech speed passed to ``Kokoro.create``.

    Returns:
        A list of rows ``[model_file, voice, iteration, inference_time,
        audio_length, rtf, sentence]``.

        * Times are in seconds.
        * ``rtf`` is inference time divided by audio length, so lower is
          faster.
        * A failed synthesis yields ``None`` times and an ``inf`` RTF.
        * Models that are missing or fail to load are reported and produce
          no rows.
    """
    voice, model_files, model_folder, voice_file = args
    results = []
    for model_file in model_files:
        model_path = os.path.join(model_folder, model_file)
        if not os.path.exists(model_path):
            print(f"❌ ERROR: Model file not found: {model_path}")
            continue  # Skip this model, don't crash the whole process
        try:
            kokoro = Kokoro(model_path, voice_file)
        except Exception as e:
            print(f"❌ ERROR: Could not load model {model_file}: {e}")
            continue  # Skip this model, don't crash the whole process
        for sentence in sentences:
            for iteration in range(1, iterations + 1):
                try:
                    print(f"✔️ Benchmarking voice: {voice}, Model: {model_file}, Iteration {iteration}, Sentence: '{sentence}'")
                    # perf_counter is monotonic and high-resolution. time.time()
                    # can jump and may be too coarse for short clips.
                    start_time = time.perf_counter()
                    samples, sample_rate = kokoro.create(sentence, voice=voice, speed=speed, lang=lang)
                    inference_time = time.perf_counter() - start_time
                    audio_length = (len(samples) / sample_rate) if (sample_rate > 0 and len(samples) > 0) else 0
                    # An unmeasurably fast run is RTF 0, not inf. Only missing
                    # audio makes the ratio undefined.
                    rtf = inference_time / audio_length if audio_length > 0 else float("inf")
                    results.append([model_file, voice, iteration, inference_time, audio_length, rtf, sentence])
                except Exception as e:
                    print(f"❌ ERROR: {voice} on {model_file} | {e}")
                    results.append([model_file, voice, iteration, None, None, float("inf"), sentence])
    return results
# Column headers shared by the CSV file and the HTML report.
_RESULT_HEADERS = ["Model", "Voice", "Iteration", "Synthesis Time", "Audio Length", "RTF", "Sentence"]


def _badge_rank(min_rtf):
    """Map the best RTF to a ``(rank label, text colour)`` pair for the badge."""
    if min_rtf == float("inf"):
        return "N/A", "grey"
    if min_rtf < 0.5:
        return "Speed Demon", "gold"
    if min_rtf < 1:
        return "Respectable Synthesizer", "silver"
    if min_rtf < 2:
        return "Not Bad, Not Great", "orange"
    if min_rtf < 5:
        return "The Chugging Potato", "brown"
    return "BRO JUST USE A GPU", "red"


def generate_badge(results, num_voices, num_models, timestamp, badge_path):
    """Render a 600x250 PNG summary badge to *badge_path*.

    The badge shows:
        * the best (lowest) positive RTF among *results* rows (RTF is the
          second-to-last column);
        * the rank label for that RTF;
        * the CPU/GPU names from get_system_info();
        * the voice and model counts.

    *timestamp* is accepted for interface compatibility but is not drawn.
    Errors are printed rather than raised, so report generation continues.
    """
    try:
        valid_rtfs = [row[-2] for row in results if isinstance(row[-2], (int, float)) and row[-2] > 0]
        min_rtf = min(valid_rtfs, default=float("inf"))
        cpu, gpu = get_system_info()
        img = Image.new("RGB", (600, 250), "black")
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("arial.ttf", 26)
            small_font = ImageFont.truetype("arial.ttf", 16)
        except OSError:
            print("Warning: arial.ttf not found, using default font. Badge may look different.")
            font = ImageFont.load_default()
            small_font = ImageFont.load_default()
        rank, color = _badge_rank(min_rtf)
        best_rtf = f"Best RTF: {min_rtf:.3f}" if min_rtf != float("inf") else "Best RTF: N/A"
        draw.text((20, 15), "Kokoro-ONNX Benchmark", fill="white", font=font)
        draw.text((20, 55), best_rtf, fill="cyan", font=font)
        # Give the rank its own line. At x=400 most labels ran past the 600px edge.
        draw.text((20, 95), rank, fill=color, font=font)
        draw.text((20, 150), f"CPU: {cpu}", fill="lime", font=small_font)
        draw.text((20, 180), f"GPU: {gpu}", fill="lime", font=small_font)
        draw.text((20, 210), f"Voices: {num_voices}, Models: {num_models}", fill="orange", font=small_font)
        img.save(badge_path)
    except Exception as e:
        print(f"❌ Error generating badge: {e}")


def _select_voices(voice_list):
    """Prompt until at least one valid voice number is entered; return the chosen names.

    Numbers are 1-based positions in *voice_list*. Repeated numbers are kept
    once, in first-seen order. Exits with status 1 if input ends (EOF) or is
    interrupted, since the prompt could otherwise loop forever.
    """
    while True:
        try:
            user_input = input("\nEnter voice numbers to test (comma-separated, e.g., '1,5,7'): ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n❌ Voice selection canceled (no input available).")
            sys.exit(1)
        if not user_input:
            print("❌ No input detected! Please enter at least one valid number.")
            continue
        selected_indices = []
        for token in user_input.split(","):
            token = token.strip()
            # isdecimal() accepts exactly what int() accepts, unlike isdigit() (e.g. "²").
            if token.isdecimal():
                index = int(token) - 1
                if 0 <= index < len(voice_list):
                    if index not in selected_indices:
                        selected_indices.append(index)
                else:
                    print(f"❌ Invalid selection: {token}. Pick numbers between 1 and {len(voice_list)}")
            elif token:  # Ignore empty fields such as a trailing comma.
                print(f"❌ Invalid input: '{token}'. Please enter numbers only.")
        if selected_indices:
            return [voice_list[i] for i in selected_indices]
        print("❌ No valid voices selected. Please try again.")


def _write_csv(csv_results_path, rows):
    """Write *rows* under a header line to *csv_results_path*. Report, but survive, write errors."""
    try:
        with open(csv_results_path, "w", newline="", encoding="utf-8") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(_RESULT_HEADERS)
            csv_writer.writerows(rows)
        print(f"Benchmark complete! Results saved to {csv_results_path}")
    except (OSError, csv.Error) as e:
        print(f"❌ Error writing CSV file: {e}")


def _write_html_report(html_results_path, badge_path, timestamp, rows):
    """Write the HTML results page and open it in the default browser. Report, but survive, errors."""
    try:
        # None (failed measurements) is shown as "N/A".
        table_data = [["N/A" if item is None else item for item in row] for row in rows]
        plain_table = tabulate(table_data, headers=_RESULT_HEADERS, tablefmt="plain")
        # tabulate's "html" format escapes cell contents.
        html_table = tabulate(table_data, headers=_RESULT_HEADERS, tablefmt="html")
        if os.path.exists(badge_path):
            # The report is written next to the badge, so a relative link works.
            # An absolute Windows path ("C:\...") is not a valid URL.
            badge_src = html.escape(os.path.basename(badge_path), quote=True)
            badge_html = f"<img src='{badge_src}'><br>"
        else:
            badge_html = "<p>Badge could not be generated.</p>"
        # The plain-text table goes in an escaped, hidden textarea rather than
        # being spliced into inline JavaScript, so no cell text can break the page.
        html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Kokoro Benchmark</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
h2 {{ color: #333; }}
table {{ width: 100%; border-collapse: collapse; }}
th, td {{ padding: 10px; border: 1px solid #ddd; text-align: left; }}
th {{ background-color: #f4f4f4; }}
button {{ margin-top: 10px; padding: 10px; font-size: 16px; cursor: pointer; }}
</style>
</head>
<body>
<h2>Kokoro ONNX Benchmark Results ({html.escape(str(timestamp))})</h2>
{badge_html}
<textarea id="results-text" readonly hidden>{html.escape(plain_table)}</textarea>
<button onclick="navigator.clipboard.writeText(document.getElementById('results-text').value)">Copy Results</button>
{html_table}
</body>
</html>
"""
        with open(html_results_path, "w", encoding="utf-8") as f:
            f.write(html_content)
        print(f"HTML report saved to {html_results_path}")
        # webbrowser expects a URL. Convert the filesystem path to a file:// URI.
        webbrowser.open(Path(html_results_path).resolve().as_uri())
    except Exception as e:
        print(f"❌ Error generating HTML report: {e}")


def main():
    """Run the interactive benchmark.

    Steps: pick voices, time them on every model in parallel, then write the
    CSV, the badge and the HTML report.
    """
    (model_folder, _benchmark_folder, timestamp, csv_results_path,
     html_results_path, badge_path, voice_file) = initialize_paths()
    onnx_models = find_onnx_models(model_folder)
    voice_list = load_voices(model_folder, onnx_models, voice_file)

    print("\n🔊 **Available Voices:**")
    for number, voice in enumerate(voice_list, start=1):
        print(f" [{number}] {voice}")
    selected_voices = _select_voices(voice_list)

    # One worker per selected voice, capped at the CPU count.
    # NOTE(review): workers run concurrently and compete for the CPU, which
    # likely inflates per-run timings versus a sequential run. Confirm before
    # comparing numbers across machines.
    user_parallel = max(1, min(len(selected_voices), multiprocessing.cpu_count()))
    print(f"\n⏳ Processing... Please wait. Using {user_parallel} processes.")
    jobs = [(voice, onnx_models, model_folder, voice_file) for voice in selected_voices]
    try:
        # Leaving the with-block, normally or via an exception, terminates the pool.
        with multiprocessing.Pool(user_parallel) as pool:
            all_results = pool.map(benchmark_voice, jobs)
    except KeyboardInterrupt:
        print("\n❌ Benchmark canceled by user.")
        sys.exit(1)
    except Exception as e:
        print(f"\n❌ An unexpected error occurred during multiprocessing: {e}")
        sys.exit(1)

    flat_results = [row for voice_rows in all_results for row in voice_rows]
    _write_csv(csv_results_path, flat_results)
    generate_badge(flat_results, len(selected_voices), len(onnx_models), timestamp, badge_path)
    _write_html_report(html_results_path, badge_path, timestamp, flat_results)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment