Created
August 26, 2025 03:15
-
-
Save karminski/a173fa105a1078e596973ca34b326969 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Set up and launch the Microsoft VibeVoice Gradio demo on a CUDA GPU.

Steps:
1. Clone the VibeVoice repository into ``./VibeVoice`` unless it already exists.
2. ``pip install -e .`` the package into the *current* interpreter.
3. Patch ``demo/gradio_demo.py`` so the model is loaded on CUDA without
   ``attn_implementation="flash_attention_2"`` (flash-attn is often not
   installed or not supported by the GPU).
4. Launch the Gradio demo with a public share link and exit with its status.
"""

import os
import re
import subprocess
import sys
from pathlib import Path

REPO_URL = "https://github.com/microsoft/VibeVoice.git"
REPO_DIR = "VibeVoice"
# Relative to REPO_DIR; only valid after main() has chdir'ed into the repo.
DEMO_SCRIPT_PATH = Path("demo/gradio_demo.py")
MODEL_ID = "microsoft/VibeVoice-1.5B"

# Matches a whole source line that passes flash_attention_2 as a keyword
# argument, e.g. `    attn_implementation="flash_attention_2",`.
# Dropping the entire line (indentation, optional trailing comma and newline
# included) keeps the surrounding call syntactically valid. Lines with extra
# logic (e.g. a conditional expression) deliberately do not match.
_FLASH_ATTN_LINE_RE = re.compile(
    r"""^[ \t]*attn_implementation[ \t]*=[ \t]*["']flash_attention_2["'][ \t]*,?[ \t]*(?:\r?\n|\Z)""",
    re.MULTILINE,
)


def clone_repo(repo_url: str = REPO_URL, repo_dir: str = REPO_DIR) -> None:
    """Clone *repo_url* into *repo_dir* unless that directory already exists.

    Exits the process with status 1 if git is missing or the clone fails.
    """
    if os.path.isdir(repo_dir):
        print("Repository already exists. Skipping clone.")
        return
    print("Cloning the VibeVoice repository...")
    try:
        # Pass the target directory explicitly instead of relying on git's
        # default (derived from the URL) happening to equal repo_dir.
        subprocess.run(
            ["git", "clone", repo_url, repo_dir],
            check=True,
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        print("Error cloning repository: 'git' executable not found on PATH.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
    print("Repository cloned successfully.")


def install_package() -> None:
    """Install the package in the current directory in editable mode.

    Uses ``sys.executable -m pip`` so the package lands in the same
    interpreter that later runs the demo. Exits with status 1 on failure.
    """
    print("Installing the VibeVoice package...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error installing package: {e.stderr}")
        sys.exit(1)
    print("Package installed successfully.")


def strip_flash_attention(source: str) -> "tuple[str, int]":
    """Remove standalone ``attn_implementation="flash_attention_2"`` argument lines.

    Returns:
        ``(new_source, number_of_lines_removed)``.
    """
    return _FLASH_ATTN_LINE_RE.subn("", source)


def patch_demo_script(demo_script_path: Path = DEMO_SCRIPT_PATH) -> None:
    """Rewrite *demo_script_path* in place so the model loads without flash-attention.

    The file is read and written as UTF-8 explicitly so the result does not
    depend on the platform's locale encoding. Exits with status 1 if the file
    cannot be read, decoded, or written.
    """
    print(f"Modifying {demo_script_path} for GPU execution...")
    try:
        content = demo_script_path.read_text(encoding="utf-8")
        patched, removed = strip_flash_attention(content)
        if removed:
            demo_script_path.write_text(patched, encoding="utf-8")
            print("Script modified successfully (removed flash_attention_2).")
        elif "flash_attention_2" not in content:
            print("GPU configuration already correct (no flash_attention_2).")
        else:
            # flash_attention_2 is referenced, but not as a plain argument line
            # we can safely drop; leave the file untouched.
            print("Warning: GPU-specific model loading block not found. The script might have been updated. Proceeding without modification.")
    except (OSError, UnicodeError) as e:
        print(f"An error occurred while modifying the script: {e}")
        sys.exit(1)


def launch_demo(demo_script_path: Path = DEMO_SCRIPT_PATH, model_id: str = MODEL_ID) -> int:
    """Run the Gradio demo (blocking) and return its exit status.

    The demo is started with ``sys.executable`` rather than a bare ``python``
    so it runs in the interpreter that the package was installed into.
    """
    command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
    print(f"Launching Gradio demo with command: {' '.join(command)}")
    return subprocess.run(command).returncode


def main() -> None:
    """Clone, install, patch, and launch; exit with the demo's exit status."""
    clone_repo()
    os.chdir(REPO_DIR)
    print(f"Changed directory to: {os.getcwd()}")
    install_package()
    patch_demo_script()
    sys.exit(launch_demo())


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment