Skip to content

Instantly share code, notes, and snippets.

@Norod
Created January 26, 2026 15:55
Show Gist options
  • Select an option

  • Save Norod/5c23ab87fc614ccc71532ece7fcebe52 to your computer and use it in GitHub Desktop.

Select an option

Save Norod/5c23ab87fc614ccc71532ece7fcebe52 to your computer and use it in GitHub Desktop.
Adds MPS (Metal) GPU support for Real-ESRGAN and GFPGAN on Apple Silicon by patching their device selection logic
#!/usr/bin/env python3
"""
Patch Real-ESRGAN and GFPGAN to support Apple MPS (Metal Performance Shaders).
This script adds MPS device detection to both packages, enabling GPU acceleration
on Apple Silicon Macs.
Usage:
python patch_mps_support.py [--realesrgan-path /path/to/Real-ESRGAN]
If --realesrgan-path is not provided, only gfpgan (from site-packages) is patched.
"""
import argparse
import re
import sys
from pathlib import Path
# Exact-text search/replace pairs handed to patch_file(): each *_OLD literal is
# searched for in the target utils.py and swapped for the matching *_NEW one.
# NOTE(review): patch_file() uses plain substring matching, so the whitespace
# inside these literals must equal the target file byte-for-byte. The snippets
# below carry no leading indentation - presumably flattened when this copy was
# made; verify against the installed Real-ESRGAN / GFPGAN sources before use.

# Real-ESRGAN device selection before patching: CUDA (optionally a specific
# gpu_id) or CPU, unless the caller passed an explicit device.
REALESRGAN_OLD = ''' # initialize model
if gpu_id:
self.device = torch.device(
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device'''
# Real-ESRGAN replacement. Priority order: explicit device, CUDA with gpu_id,
# default CUDA, MPS, and finally CPU.
REALESRGAN_NEW = ''' # initialize model
if device is not None:
self.device = device
elif gpu_id:
self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
elif torch.cuda.is_available():
self.device = torch.device('cuda')
elif torch.backends.mps.is_available():
self.device = torch.device('mps')
else:
self.device = torch.device('cpu')'''
# GFPGAN device selection before patching: CUDA or CPU, unless the caller
# passed an explicit device.
GFPGAN_OLD = ''' # initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device'''
# GFPGAN replacement. Priority order: explicit device, CUDA, MPS, CPU.
GFPGAN_NEW = ''' # initialize model
if device is not None:
self.device = device
elif torch.cuda.is_available():
self.device = torch.device('cuda')
elif torch.backends.mps.is_available():
self.device = torch.device('mps')
else:
self.device = torch.device('cpu')'''
def patch_file(filepath: Path, old: str, new: str, name: str) -> bool:
    """Replace every occurrence of ``old`` with ``new`` in the file at ``filepath``.

    The outcome is reported on stdout as a ``[SKIP]``, ``[OK]``, ``[WARN]`` or
    ``[DONE]`` line labelled with ``name``.

    Args:
        filepath: File to patch in place.
        old: Exact text expected in the unpatched file.
        new: Replacement text; if it is already present the file is treated
            as patched and left untouched.
        name: Human-readable label used in the status messages.

    Returns:
        True if the file already contained ``new`` or was rewritten
        successfully; False if ``filepath`` is not a regular file or ``old``
        was not found in it.
    """
    # is_file() instead of exists(): a directory "exists" but cannot be read,
    # which would otherwise surface as an unhandled IsADirectoryError.
    if not filepath.is_file():
        print(f"[SKIP] {name}: File not found at {filepath}")
        return False

    # Python sources default to UTF-8; relying on the locale encoding could
    # fail on, or silently corrupt, non-ASCII characters in the target file.
    content = filepath.read_text(encoding="utf-8")

    if new in content:
        print(f"[OK] {name}: Already patched")
        return True

    if old not in content:
        print(f"[WARN] {name}: Original pattern not found - may already be modified")
        return False

    filepath.write_text(content.replace(old, new), encoding="utf-8")
    print(f"[DONE] {name}: Patched successfully")
    return True
def find_gfpgan_path() -> "Path | None":
    """Locate ``utils.py`` of the installed ``gfpgan`` package.

    The package is located through the import system's finders without being
    executed, so a heavy or broken dependency of gfpgan cannot make an
    installed package look missing or crash the patcher.

    Returns:
        Path to ``<gfpgan package dir>/utils.py`` (not checked for existence),
        or None when no ``gfpgan`` package with a file location is found.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec("gfpgan")
    except (ImportError, ValueError):
        # ValueError is raised when an already-imported module has no __spec__.
        return None

    # No spec: not installed. No file-backed origin: nothing to patch on disk.
    if spec is None or not spec.has_location or not spec.origin:
        return None

    # For a regular package, origin is .../gfpgan/__init__.py.
    return Path(spec.origin).parent / "utils.py"
def main(argv: "list[str] | None" = None) -> int:
    """Run the command-line patcher.

    Parses the options, reports whether PyTorch's MPS backend is usable, then
    patches GFPGAN (located in the current environment) and, when
    ``--realesrgan-path`` is given, the Real-ESRGAN checkout at that path.
    With ``--dry-run`` the target files are only listed, never modified.

    Args:
        argv: Argument list to parse; None (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 when every attempted patch succeeded (or nothing
        was attempted), 1 when at least one patch_file() call reported failure.
    """
    parser = argparse.ArgumentParser(description="Patch Real-ESRGAN and GFPGAN for MPS support")
    parser.add_argument(
        "--realesrgan-path",
        type=Path,
        help="Path to Real-ESRGAN repository (optional)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be patched without making changes"
    )
    args = parser.parse_args(argv)

    banner = "=" * 60
    print(banner)
    print("Real-ESRGAN & GFPGAN MPS Support Patcher")
    print(banner)
    print()

    # Check MPS availability (informational only; patching proceeds regardless).
    try:
        import torch
    except ImportError:
        print("[WARN] PyTorch not installed - cannot check MPS availability")
    else:
        # A PyTorch build may lack torch.backends.mps; treat a missing backend
        # as "not available" instead of crashing with AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        mps_available = bool(mps_backend is not None and mps_backend.is_available())
        mps_built = bool(mps_backend is not None and mps_backend.is_built())
        print(f"PyTorch MPS available: {mps_available}")
        print(f"PyTorch MPS built: {mps_built}")
        if not mps_available:
            print("[WARN] MPS not available - patches will still be applied but won't have effect")
    print()

    success = True

    # Patch GFPGAN
    gfpgan_path = find_gfpgan_path()
    if gfpgan_path is None:
        print("[SKIP] GFPGAN: Not installed")
    elif args.dry_run:
        print(f"[DRY] Would patch GFPGAN at: {gfpgan_path}")
    else:
        # patch_file() is evaluated first so it always runs, even after an
        # earlier failure has already cleared `success`.
        success = patch_file(gfpgan_path, GFPGAN_OLD, GFPGAN_NEW, "GFPGAN") and success

    # Patch Real-ESRGAN
    if args.realesrgan_path is None:
        print("[SKIP] Real-ESRGAN: No --realesrgan-path provided")
    else:
        realesrgan_utils = args.realesrgan_path / "realesrgan" / "utils.py"
        if args.dry_run:
            print(f"[DRY] Would patch Real-ESRGAN at: {realesrgan_utils}")
        else:
            success = patch_file(realesrgan_utils, REALESRGAN_OLD, REALESRGAN_NEW, "Real-ESRGAN") and success

    print()
    print(banner)
    if success:
        print("Patching complete!")
    else:
        print("Patching completed with warnings - check output above")
    return 0 if success else 1
# Script entry point: run the patcher and hand its return value to the
# interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment