A simple backend service that exposes SAM 3D model inference over HTTP. It provides three endpoints: /health and /stats for GPU status, and /reconstruct, which accepts an image (base64 JSON or multipart form-data) plus an optional mask and seed, and returns a Gaussian-splat PLY file as base64.
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import sys
import os

sys.path.append("notebook")
from inference import Inference

from PIL import Image
import base64
import io
from datetime import datetime
import traceback
from functools import wraps
import numpy as np

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# API security - set the API_SECRET environment variable to require a key
API_SECRET = os.environ.get('API_SECRET', None)

def require_api_key(f):
    """Decorator to require API key in X-API-Key header"""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if API_SECRET is None:
            # No security configured, allow access
            return f(*args, **kwargs)
        provided_key = request.headers.get('X-API-Key')
        if not provided_key:
            return jsonify({
                "success": False,
                "error": "Missing X-API-Key header"
            }), 401
        if provided_key != API_SECRET:
            return jsonify({
                "success": False,
                "error": "Invalid API key"
            }), 403
        return f(*args, **kwargs)
    return decorated_function

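# Example authenticated request (assuming the server runs on the default
# host/port configured at the bottom of this file):
#   curl -H "X-API-Key: $API_SECRET" http://localhost:8000/health
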
# Initialize model
print("Loading SAM 3D Objects model...")
config_path = "checkpoints/hf/checkpoints/pipeline.yaml"

if not os.path.exists(config_path):
    print(f"ERROR: Config file not found at {config_path}")
    print("Available files in checkpoints/hf/checkpoints/:")
    os.system("ls -la checkpoints/hf/checkpoints/")
    sys.exit(1)

try:
    inference_model = Inference(config_path, compile=False)
    # Enable TF32 optimizations (effective on Ampere and newer GPUs)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    print(f"✓ Model loaded on: {torch.cuda.get_device_name(0)}")
    print(f"✓ Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
except Exception as e:
    print(f"ERROR loading model: {e}")
    traceback.print_exc()
    sys.exit(1)

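# Note: TF32 trades a small amount of matmul precision for throughput; on
# pre-Ampere GPUs these flags have no effect, so it is safe to leave them
# on unconditionally.
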
@app.route('/health', methods=['GET'])
@require_api_key
def health():
    return jsonify({
        "status": "healthy",
        "gpu": torch.cuda.get_device_name(0),
        "vram_total_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
        "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
        "model_loaded": True
    })

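# Example /health response (values illustrative):
#   {"status": "healthy", "gpu": "NVIDIA GeForce RTX 4090", "vram_total_gb": 24.0,
#    "vram_allocated_gb": 6.5, "model_loaded": true}
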
@app.route('/reconstruct', methods=['POST'])
@require_api_key
def reconstruct_3d():
    try:
        start_time = datetime.now()

        # Handle both JSON and form data
        if request.is_json:
            data = request.json
            image_data = base64.b64decode(data['image'])
        else:
            # Handle file upload
            if 'image' not in request.files:
                return jsonify({"success": False, "error": "No image provided"}), 400
            image_file = request.files['image']
            image_data = image_file.read()

        # Load image
        image = Image.open(io.BytesIO(image_data))
        print(f"Processing image: {image.size}, mode: {image.mode}")

        # Convert PIL Image to numpy array for inference
        image_array = np.array(image)

        # Handle mask if provided
        mask = None
        if request.is_json and 'mask' in data:
            mask_data = base64.b64decode(data['mask'])
            mask = Image.open(io.BytesIO(mask_data))
        elif 'mask' in request.files:
            mask_data = request.files['mask'].read()
            mask = Image.open(io.BytesIO(mask_data))

        if mask is None:
            # No mask provided: use a full-coverage mask (treat the whole image as the object)
            mask = np.ones((image_array.shape[0], image_array.shape[1]), dtype=np.float32)
        else:
            # Convert the PIL mask to a 2D float32 array
            mask = np.array(mask)
            if mask.ndim == 3:
                mask = mask[:, :, 0]  # Take first channel if RGB
            mask = mask.astype(np.float32)
            if mask.max() > 1.0:
                mask /= 255.0  # Scale 8-bit masks to [0, 1] to match the default mask's range

        # Get seed
        seed = 42
        if request.is_json:
            seed = data.get('seed', 42)
        elif 'seed' in request.form:
            seed = int(request.form['seed'])

        # Run inference
        print(f"Running inference with seed={seed}...")
        output = inference_model(image_array, mask, seed=seed)

        # Save output
        os.makedirs('/tmp/outputs', exist_ok=True)
        timestamp = int(datetime.now().timestamp())
        output_path = f"/tmp/outputs/output_{seed}_{timestamp}.ply"
        output["gs"].save_ply(output_path)

        # Get file size
        file_size = os.path.getsize(output_path)

        # Return as base64
        with open(output_path, 'rb') as f:
            ply_data = base64.b64encode(f.read()).decode()

        elapsed = (datetime.now() - start_time).total_seconds()
        gpu_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"✓ Completed in {elapsed:.2f}s, GPU memory: {gpu_memory_gb:.2f}GB")

        return jsonify({
            "success": True,
            "ply_file": ply_data,
            "file_size_mb": round(file_size / 1024**2, 2),
            "format": "gaussian_splat",
            "processing_time_seconds": round(elapsed, 2),
            "gpu_memory_used_gb": round(gpu_memory_gb, 2),
            "image_size": [image_array.shape[1], image_array.shape[0]],
            "seed": seed
        })
    except Exception as e:
        print(f"ERROR: {e}")
        traceback.print_exc()
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500

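# Minimal client sketch for /reconstruct (assumes the `requests` package and a
# local server; file names and the API key value are hypothetical):
#
#   import base64, requests
#   with open("chair.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/reconstruct",
#       json={"image": img_b64, "seed": 42},
#       headers={"X-API-Key": "my-secret"},
#   )
#   result = resp.json()
#   with open("output.ply", "wb") as f:
#       f.write(base64.b64decode(result["ply_file"]))
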
@app.route('/stats', methods=['GET'])
@require_api_key
def gpu_stats():
    return jsonify({
        "gpu_name": torch.cuda.get_device_name(0),
        "total_memory_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
        "allocated_memory_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
        "cached_memory_gb": round(torch.cuda.memory_reserved() / 1024**3, 2),
        "available_memory_gb": round((torch.cuda.get_device_properties(0).total_memory -
                                      torch.cuda.memory_allocated()) / 1024**3, 2)
    })

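# Note: "available_memory_gb" is computed as total minus allocated; PyTorch's
# caching allocator may reserve more than is currently allocated (see
# cached_memory_gb), so memory actually free to other processes can be lower.
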
@app.route('/', methods=['GET'])
def index():
    return jsonify({
        "service": "SAM 3D Objects API",
        "version": "1.0",
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU",
        "endpoints": {
            "/health": "GET - Health check",
            "/stats": "GET - GPU statistics",
            "/reconstruct": "POST - 3D reconstruction (JSON with base64 image or multipart form-data)"
        }
    })

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000, debug=False)
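
# To run locally (the module filename is assumed; adjust to the actual file name):
#   API_SECRET=my-secret python server.py
# For anything beyond local testing, a WSGI server is the usual choice, e.g.:
#   gunicorn -w 1 -b 0.0.0.0:8000 server:app
# A single worker keeps only one copy of the model in GPU memory.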