Created
January 26, 2026 15:55
-
-
Save Norod/5c23ab87fc614ccc71532ece7fcebe52 to your computer and use it in GitHub Desktop.
Adds MPS (Metal) GPU support to Real-ESRGAN and GFPGAN on Apple Silicon by patching their device-selection logic.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Patch Real-ESRGAN and GFPGAN to support Apple MPS (Metal Performance Shaders). | |
| This script adds MPS device detection to both packages, enabling GPU acceleration | |
| on Apple Silicon Macs. | |
| Usage: | |
| python patch_mps_support.py [--realesrgan-path /path/to/Real-ESRGAN] | |
| If --realesrgan-path is not provided, only gfpgan (from site-packages) is patched. | |
| """ | |
| import argparse | |
| import re | |
| import sys | |
| from pathlib import Path | |
# Exact snippets from the unpatched ``__init__`` methods in
# realesrgan/utils.py and gfpgan/utils.py, plus their MPS-aware replacements.
# Whitespace is significant: patch_file() does a literal substring search and
# replace, so the 8-space method-body indentation must match the target files
# character for character.
REALESRGAN_OLD = '''        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device'''

# Selection order: explicit ``device`` > explicit ``gpu_id`` (CUDA, else CPU)
# > CUDA > MPS > CPU. Note that an explicit gpu_id still falls back to CPU,
# not MPS, when CUDA is unavailable.
# NOTE(review): the replacement calls torch.backends.mps unguarded, which
# assumes PyTorch >= 1.12 in the patched environment. The text is kept as-is
# so that files patched by earlier runs are still recognised as
# "Already patched".
REALESRGAN_NEW = '''        # initialize model
        if device is not None:
            self.device = device
        elif gpu_id:
            self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')'''

GFPGAN_OLD = '''        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device'''

# Selection order: explicit ``device`` > CUDA > MPS > CPU.
GFPGAN_NEW = '''        # initialize model
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')'''
def patch_file(filepath: Path, old: str, new: str, name: str) -> bool:
    """Patch a file by replacing ``old`` with ``new``; safe to run repeatedly.

    Args:
        filepath: Source file to modify in place.
        old: Exact text expected in the unpatched file.
        new: Replacement text. Its presence marks the file as already patched.
        name: Human-readable label used in status messages.

    Returns:
        True if the file is now patched (by this call or an earlier one).
        False if the file is missing, the expected text is absent, or the file
        could not be read or written. Failures are reported on stdout rather
        than raised.
    """
    if not filepath.exists():
        print(f"[SKIP] {name}: File not found at {filepath}")
        return False

    try:
        # Python sources default to UTF-8; don't depend on the locale encoding.
        content = filepath.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as err:
        print(f"[ERROR] {name}: Could not read {filepath}: {err}")
        return False

    # Check for the replacement first so repeated runs are no-ops.
    if new in content:
        print(f"[OK] {name}: Already patched")
        return True
    if old not in content:
        print(f"[WARN] {name}: Original pattern not found - may already be modified")
        return False

    try:
        filepath.write_text(content.replace(old, new), encoding="utf-8")
    except OSError as err:
        # e.g. a read-only site-packages: report instead of dying with a traceback.
        print(f"[ERROR] {name}: Could not write {filepath}: {err}")
        return False

    print(f"[DONE] {name}: Patched successfully")
    return True
def find_gfpgan_path():
    """Locate the installed ``gfpgan/utils.py`` without importing gfpgan.

    ``importlib.util.find_spec`` only resolves where the package lives; it does
    not execute ``gfpgan/__init__.py``. This means an installed gfpgan is still
    found even when importing it would fail, for example because one of its
    dependencies cannot be imported.

    Returns:
        Path to ``utils.py`` inside the gfpgan package directory, or None when
        no gfpgan package can be located.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec("gfpgan")
    except (ImportError, ValueError):
        # ValueError: gfpgan is already in sys.modules with no __spec__.
        return None
    # origin is None for namespace packages, which have no __init__.py location.
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin).parent / "utils.py"
def _report_mps_status():
    """Print whether this PyTorch install has an MPS (Metal) backend available.

    Informational only: patching proceeds regardless of the result.
    """
    try:
        import torch
    except ImportError:
        print("[WARN] PyTorch not installed - cannot check MPS availability")
        return

    # torch.backends.mps was added in PyTorch 1.12. On older builds the
    # attribute is missing, so look it up defensively instead of crashing
    # with AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    mps_built = mps_backend is not None and mps_backend.is_built()
    print(f"PyTorch MPS available: {mps_available}")
    print(f"PyTorch MPS built: {mps_built}")
    if not mps_available:
        print("[WARN] MPS not available - patches will still be applied but won't have effect")


def main():
    """Parse CLI arguments, report MPS status, then patch GFPGAN and Real-ESRGAN.

    GFPGAN is patched whenever it can be located. Real-ESRGAN is patched only
    when ``--realesrgan-path`` is given. With ``--dry-run`` nothing is written.

    Returns:
        Process exit code: 0 if every attempted patch succeeded (or was
        already applied), 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Patch Real-ESRGAN and GFPGAN for MPS support")
    parser.add_argument(
        "--realesrgan-path",
        type=Path,
        help="Path to Real-ESRGAN repository (optional)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be patched without making changes",
    )
    args = parser.parse_args()

    rule = "=" * 60
    print(rule)
    print("Real-ESRGAN & GFPGAN MPS Support Patcher")
    print(rule)
    print()

    _report_mps_status()
    print()

    success = True

    # GFPGAN: patched in place wherever the installed package lives.
    gfpgan_path = find_gfpgan_path()
    if gfpgan_path is None:
        print("[SKIP] GFPGAN: Not installed")
    elif args.dry_run:
        print(f"[DRY] Would patch GFPGAN at: {gfpgan_path}")
    else:
        success &= patch_file(gfpgan_path, GFPGAN_OLD, GFPGAN_NEW, "GFPGAN")

    # Real-ESRGAN: only patched when the user points at a checkout.
    if args.realesrgan_path is None:
        print("[SKIP] Real-ESRGAN: No --realesrgan-path provided")
    else:
        realesrgan_utils = args.realesrgan_path / "realesrgan" / "utils.py"
        if args.dry_run:
            print(f"[DRY] Would patch Real-ESRGAN at: {realesrgan_utils}")
        else:
            success &= patch_file(realesrgan_utils, REALESRGAN_OLD, REALESRGAN_NEW, "Real-ESRGAN")

    print()
    print(rule)
    if args.dry_run:
        print("Dry run complete - no files were modified")
    elif success:
        print("Patching complete!")
    else:
        print("Patching completed with warnings - check output above")
    return 0 if success else 1
if __name__ == "__main__":
    # Run the CLI and hand its integer status back to the shell.
    exit_code = main()
    sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment