Last active
February 21, 2025 18:08
-
-
Save AcTePuKc/4cfe0426aae1c50048382302767f9f8d to your computer and use it in GitHub Desktop.
Kokoro-ONNX Benchmark CPU/GPU
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import html
import multiprocessing
import os
import platform
import subprocess
import sys
import time
import webbrowser
from datetime import datetime
from pathlib import Path

import torch
from kokoro_onnx import Kokoro  # Assuming kokoro_onnx is installed or in the same directory.
from PIL import Image, ImageDraw, ImageFont
from tabulate import tabulate
# Get system information (CPU & GPU detection) - More robust handling
def _run_text_command(command):
    """Run *command* (an argv list, no shell) and return its stdout as text.

    Returns an empty string when the program is missing, exits non-zero, or
    prints bytes that cannot be decoded, so callers only have to test the
    result for emptiness.
    """
    try:
        return subprocess.check_output(command, stderr=subprocess.DEVNULL, text=True)
    except (subprocess.SubprocessError, OSError, ValueError):
        return ""


def _detect_cpu_name():
    """Return the CPU model name reported by the OS, or "Unknown CPU"."""
    system = platform.system()
    candidates = []
    if system == "Windows":
        # `wmic cpu get name` prints a "Name" header followed by the value,
        # often separated by blank lines; keep only real, non-header lines.
        output = _run_text_command(["wmic", "cpu", "get", "name"])
        stripped = (line.strip() for line in output.splitlines())
        candidates = [line for line in stripped if line and line.lower() != "name"]
    elif system == "Darwin":
        # lscpu does not exist on macOS; sysctl exposes the brand string.
        output = _run_text_command(["sysctl", "-n", "machdep.cpu.brand_string"]).strip()
        if output:
            candidates = [output]
    else:
        # Compare the key exactly: newer lscpu versions also print a
        # "BIOS Model name" line, which a substring match would pick up too.
        for line in _run_text_command(["lscpu"]).splitlines():
            key, sep, value = line.partition(":")
            if sep and key.strip() == "Model name" and value.strip():
                candidates.append(value.strip())
    if candidates:
        return candidates[0]
    # Fallback when the OS tool is missing or printed nothing usable.
    return (platform.processor() or "").strip() or "Unknown CPU"


def get_system_info():
    """Detect the CPU and (CUDA) GPU names of the current machine.

    Returns:
        A ``(cpu_name, gpu_name)`` tuple of non-empty strings. ``cpu_name``
        falls back to ``platform.processor()`` and then ``"Unknown CPU"``;
        ``gpu_name`` is ``"No GPU"`` when torch reports no CUDA device.

    CPU and GPU detection are independent: a failure in one is logged and
    does not prevent the other from running.
    """
    cpu_name = "Unknown CPU"
    gpu_name = "No GPU"

    try:
        cpu_name = _detect_cpu_name()
    except Exception as e:
        print(f"Error getting system info: {e}")  # Log any unexpected errors

    try:
        if torch.cuda.is_available():
            try:
                gpu_name = torch.cuda.get_device_name(0)
            except Exception:
                gpu_name = "GPU Detected (Error getting name)"  # More informative
    except Exception as e:
        print(f"Error getting system info: {e}")  # Log any unexpected errors

    return cpu_name, gpu_name
# Paths and timestamp for unique file names - Use consistent separators
def initialize_paths():
    """Resolve every path the benchmark reads or writes.

    The ``benchmark`` output folder is created next to this script if it is
    missing, and all output file names share one timestamp so that the files
    of a single run can be matched up.

    Returns:
        ``(model_folder, benchmark_folder, timestamp, csv_results_path,
        html_results_path, badge_path, voice_file)``.

    Exits the process with status 1 when ``voices-v1.0.bin`` is not found
    next to the script.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))  # Folder holding this script
    output_dir = os.path.join(script_dir, "benchmark")
    os.makedirs(output_dir, exist_ok=True)

    run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def in_output_dir(file_name):
        return os.path.join(output_dir, file_name)

    csv_file = in_output_dir(f"benchmark_results_{run_stamp}.csv")
    html_file = in_output_dir(f"benchmark_results_{run_stamp}.html")
    badge_file = in_output_dir(f"kokoro_badge_{run_stamp}.png")

    # The voice file is critical for operation; nothing can run without it.
    voices_bin = os.path.join(script_dir, "voices-v1.0.bin")
    if os.path.exists(voices_bin):
        return (script_dir, output_dir, run_stamp, csv_file, html_file, badge_file, voices_bin)

    print(f"β ERROR: voice file not found: {voices_bin}. Please ensure it exists.")
    sys.exit(1)  # Exit; cannot continue without the voice file
# Find ONNX models and validate their existence
def find_onnx_models(model_folder):
    """Return the sorted file names of all ONNX models in *model_folder*.

    Only regular files count, so a directory that happens to be named
    ``something.onnx`` is not reported as a model. The extension is matched
    case-insensitively, so ``MODEL.ONNX`` is found as well.

    Args:
        model_folder: Directory to scan (not searched recursively).

    Returns:
        A sorted list of bare file names (not full paths).

    Exits the process with status 1 when no model is found.
    """
    onnx_models = sorted(
        entry
        for entry in os.listdir(model_folder)
        if entry.lower().endswith(".onnx")
        and os.path.isfile(os.path.join(model_folder, entry))
    )
    if not onnx_models:
        print("β ERROR: No .onnx models found in the directory.")
        sys.exit(1)  # Exit; benchmark cannot proceed without models
    return onnx_models
# Load voices dynamically from *a* model - Handle potential errors
def load_voices(model_folder, onnx_models, voice_file):
    """Return the names of the voices available in *voice_file*.

    Args:
        model_folder: Directory containing the ONNX model files.
        onnx_models: Model file names; only the first one is loaded.
        voice_file: Path of the voices file handed to ``Kokoro``.

    Returns:
        A list with the keys of the loaded ``Kokoro`` instance's ``voices``.

    Exits the process with status 1 if anything goes wrong while loading.
    """
    try:
        # Any model will do here: all of them are assumed to share one voice set.
        first_model_path = os.path.join(model_folder, onnx_models[0])
        probe = Kokoro(first_model_path, voice_file)
        voice_names = [*probe.voices.keys()]
    except Exception as e:
        print(f"β ERROR: Could not load voices from the model: {e}")
        sys.exit(1)
    return voice_names
# Benchmark function (for multiprocessing) - More detailed error handling
def benchmark_voice(args):
    """Benchmark one voice against every model and return the measurements.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args: ``(voice, model_files, model_folder, voice_file)`` where
            ``model_files`` are file names inside ``model_folder``.

    Returns:
        A list of rows ``[model_file, voice, iteration, inference_time,
        audio_length, rtf, sentence]``. A synthesis that raised is recorded
        with ``None`` for both times and ``float('inf')`` as RTF. Models that
        are missing or fail to load are reported and contribute no rows.
    """
    voice, model_files, model_folder, voice_file = args  # Unpack all necessary arguments
    sentences = ("Hello, world!", "This is a test sentence.", "Kokoro TTS is amazing!")
    results = []

    for model_file in model_files:
        model_path = os.path.join(model_folder, model_file)
        if not os.path.exists(model_path):
            print(f"β ERROR: Model file not found: {model_path}")
            continue  # Skip this model, don't crash the whole process

        try:
            kokoro = Kokoro(model_path, voice_file)
        except Exception as e:
            print(f"β ERROR: Could not load model {model_file}: {e}")
            continue  # Skip this model, don't crash the whole process

        for sentence in sentences:
            for iteration in range(1, 3):
                try:
                    print(f"βοΈ Benchmarking voice: {voice}, Model: {model_file}, Iteration {iteration}, Sentence: '{sentence}'")
                    # perf_counter() is monotonic and high-resolution; time.time()
                    # can jump with clock adjustments and is too coarse on some
                    # platforms to time short syntheses.
                    start_time = time.perf_counter()
                    samples, sample_rate = kokoro.create(sentence, voice=voice, speed=1.0, lang="en-us")
                    inference_time = time.perf_counter() - start_time

                    # Guard both divisions: no audio means length 0 and RTF inf.
                    audio_length = (len(samples) / sample_rate) if (sample_rate > 0 and len(samples) > 0) else 0
                    rtf = (inference_time / audio_length) if (audio_length > 0 and inference_time > 0) else float('inf')
                    results.append([model_file, voice, iteration, inference_time, audio_length, rtf, sentence])
                except Exception as e:
                    print(f"β ERROR: {voice} on {model_file} | {e}")
                    results.append([model_file, voice, iteration, None, None, float('inf'), sentence])  # Record the error.

    return results
# Main execution block
def _parse_voice_selection(user_input, voice_list):
    """Turn a comma-separated list of 1-based numbers into voice names.

    Invalid tokens are reported on stdout and skipped. Repeated numbers are
    collapsed so that no voice is benchmarked twice; the order of first
    appearance is kept.

    Returns:
        The selected voice names (possibly empty).
    """
    chosen = []
    for token in user_input.split(','):
        token = token.strip()
        if not token:  # Ignore empty fields such as a trailing comma
            continue
        # isdecimal() accepts exactly the digits int() can parse; isdigit()
        # also accepts characters like superscripts, which int() rejects.
        if not token.isdecimal():
            print(f"β Invalid input: '{token}'. Please enter numbers only.")
            continue
        index = int(token) - 1
        if not 0 <= index < len(voice_list):
            print(f"β Invalid selection: {token}. Pick numbers between 1 and {len(voice_list)}")
            continue
        if index not in chosen:
            chosen.append(index)
    return [voice_list[i] for i in chosen]


if __name__ == "__main__":
    (model_folder, benchmark_folder, timestamp, csv_results_path,
     html_results_path, badge_path, voice_file) = initialize_paths()
    onnx_models = find_onnx_models(model_folder)
    voice_list = load_voices(model_folder, onnx_models, voice_file)

    if not voice_list:
        # Without this guard the selection loop below could never terminate.
        print("β ERROR: No voices available to benchmark.")
        sys.exit(1)

    print("\nπ **Available Voices:**")
    for i, voice in enumerate(voice_list):
        print(f" [{i+1}] {voice}")

    selected_voices = []
    while not selected_voices:
        try:
            user_input = input("\nEnter voice numbers to test (comma-separated, e.g., '1,5,7'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C: retrying would loop forever, so stop.
            print("\nβ Benchmark canceled by user.")
            sys.exit(1)
        if not user_input:
            print("β No input detected! Please enter at least one valid number.")
            continue
        selected_voices = _parse_voice_selection(user_input, voice_list)
        if not selected_voices:
            print("β No valid voices selected. Please try again.")

    # Determine number of processes - Prevent excessive processes
    user_parallel = max(1, min(len(selected_voices), multiprocessing.cpu_count()))
    print(f"\nβ³ Processing... Please wait. Using {user_parallel} processes.")

    tasks = [(voice, onnx_models, model_folder, voice_file) for voice in selected_voices]
    try:
        # Leaving the `with` block terminates the workers, including when an
        # exception propagates, so the handlers need no extra pool cleanup.
        with multiprocessing.Pool(user_parallel) as pool:
            all_results = pool.map(benchmark_voice, tasks)
    except KeyboardInterrupt:
        print("\nβ Benchmark canceled by user.")
        sys.exit(1)
    except Exception as e:
        print(f"\nβ An unexpected error occurred during multiprocessing: {e}")
        sys.exit(1)

    # One list of rows per voice -> a single flat list of rows.
    flat_results = [item for sublist in all_results for item in sublist]

    # CSV writing - More robust handling
    try:
        with open(csv_results_path, "w", newline="", encoding="utf-8") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(["Model", "Voice", "Iteration", "Synthesis Time", "Audio Length", "RTF", "Sentence"])
            csv_writer.writerows(flat_results)
        print(f"Benchmark complete! Results saved to {csv_results_path}")
    except Exception as e:
        # Best effort: a failed CSV write must not abort the rest of the run.
        print(f"β Error writing CSV file: {e}")
# Generate benchmark badge - Handle cases with no valid RTF
def _rank_for_rtf(min_rtf):
    """Map the best real-time factor to a ``(rank label, colour name)`` pair."""
    if min_rtf == float('inf'):
        return "N/A", "grey"
    # (exclusive upper bound, label, colour), checked from fastest to slowest.
    tiers = (
        (0.5, "Speed Demon", "gold"),
        (1, "Respectable Synthesizer", "silver"),
        (2, "Not Bad, Not Great", "orange"),
        (5, "The Chugging Potato", "brown"),
    )
    for upper_bound, rank, color in tiers:
        if min_rtf < upper_bound:
            return rank, color
    return "BRO JUST USE A GPU", "red"


def generate_badge(results, num_voices, num_models, timestamp, badge_path):
    """Render a 600x250 PNG summarising the benchmark and save it.

    Args:
        results: Benchmark rows; the second-to-last item of each row is read
            as the RTF. Non-numeric and non-positive values are ignored.
        num_voices: Number of voices tested (printed on the badge).
        num_models: Number of models tested (printed on the badge).
        timestamp: Unused; kept so existing callers keep working.
        badge_path: Destination path of the PNG file.

    Any error is reported on stdout instead of being raised, so a failed
    badge never aborts the run.
    """
    try:
        valid_rtfs = [row[-2] for row in results if isinstance(row[-2], (int, float)) and row[-2] > 0]
        min_rtf = min(valid_rtfs) if valid_rtfs else float('inf')  # Handle empty list
        cpu, gpu = get_system_info()

        img = Image.new("RGB", (600, 250), "black")
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("arial.ttf", 26)
            small_font = ImageFont.truetype("arial.ttf", 16)
        except OSError:
            print("Warning: arial.ttf not found, using default font. Badge may look different.")
            font = ImageFont.load_default()
            small_font = ImageFont.load_default()

        # Rank logic (handle inf RTF)
        rank, color = _rank_for_rtf(min_rtf)
        best_rtf_text = f"Best RTF: {min_rtf:.3f}" if min_rtf != float('inf') else "Best RTF: N/A"

        # Draw text on badge. The rank gets its own row starting at the left
        # margin: the longer labels do not fit between x=400 and the right
        # edge of the 600 px canvas in the large font and would be cut off.
        draw.text((20, 20), "Kokoro-ONNX Benchmark", fill="white", font=font)
        draw.text((20, 55), rank, fill=color, font=font)
        draw.text((20, 90), best_rtf_text, fill="cyan", font=font)
        draw.text((20, 150), f"CPU: {cpu}", fill="lime", font=small_font)
        draw.text((20, 180), f"GPU: {gpu}", fill="lime", font=small_font)
        draw.text((20, 210), f"Voices: {num_voices}, Models: {num_models}", fill="orange", font=small_font)
        img.save(badge_path)
    except Exception as e:
        print(f"β Error generating badge: {e}")
if __name__ == "__main__":
    generate_badge(flat_results, len(selected_voices), len(onnx_models), timestamp, badge_path)

    # HTML report generation - Use try-except and handle missing badge
    try:
        headers = ['Model', 'Voice', 'Iteration', 'Synthesis Time', 'Audio Length', 'RTF', 'Sentence']

        # Prepare table data, replacing None with "N/A"
        table_data = [
            [item if item is not None else "N/A" for item in row]
            for row in flat_results
        ]
        html_table = tabulate(table_data, headers=headers, tablefmt='html')
        # The plain-text copy lives in a hidden element, HTML-escaped, instead
        # of being pasted into the onclick attribute, where a quote or a
        # backtick in a model or voice name would break the markup or script.
        plain_text = html.escape(tabulate(table_data, headers=headers, tablefmt='plain'))

        if os.path.exists(badge_path):
            # The badge sits next to the report, so reference it by file name;
            # an absolute filesystem path is not a valid URL.
            badge_src = html.escape(os.path.basename(badge_path), quote=True)
            badge_html = f"<img src='{badge_src}'><br>"
        else:
            badge_html = "<p>Badge could not be generated.</p>"

        html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Kokoro Benchmark</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
h2 {{ color: #333; }}
table {{ width: 100%; border-collapse: collapse; }}
th, td {{ padding: 10px; border: 1px solid #ddd; text-align: left; }}
th {{ background-color: #f4f4f4; }}
button {{ margin-top: 10px; padding: 10px; font-size: 16px; cursor: pointer; }}
</style>
</head>
<body>
<h2>Kokoro ONNX Benchmark Results ({timestamp})</h2>
{badge_html}
<button onclick="navigator.clipboard.writeText(document.getElementById('plain-results').textContent)">Copy Results</button>
<pre id="plain-results" hidden>{plain_text}</pre>
{html_table}
</body>
</html>
"""
        with open(html_results_path, "w", encoding="utf-8") as f:
            f.write(html_content)
        print(f"HTML report saved to {html_results_path}")

        # Hand the browser a file:// URL; opening a bare file name is not
        # portable across platforms.
        webbrowser.open(Path(html_results_path).resolve().as_uri())
    except Exception as e:
        print(f"β Error generating HTML report: {e}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment