@MohaElder
Created November 22, 2025 01:38
A simple Flask backend service that exposes the SAM 3D Objects model for inference over HTTP.
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import torch
import sys
import os
sys.path.append("notebook")
from inference import Inference
from PIL import Image
import base64
import io
from datetime import datetime
import traceback
from functools import wraps
import numpy as np
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
# API Security - Set API_SECRET environment variable
API_SECRET = os.environ.get('API_SECRET', None)
def require_api_key(f):
    """Decorator to require an API key in the X-API-Key header."""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if API_SECRET is None:
            # No security configured, allow access
            return f(*args, **kwargs)
        provided_key = request.headers.get('X-API-Key')
        if not provided_key:
            return jsonify({
                "success": False,
                "error": "Missing X-API-Key header"
            }), 401
        if provided_key != API_SECRET:
            return jsonify({
                "success": False,
                "error": "Invalid API key"
            }), 403
        return f(*args, **kwargs)
    return decorated_function
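
# Example of a protected call (assuming the server runs locally on port 8000;
# "my-secret" is a placeholder for your own API_SECRET value):
#   curl -H "X-API-Key: my-secret" http://localhost:8000/health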
# Initialize model
print("Loading SAM 3D Objects model...")
config_path = "checkpoints/hf/checkpoints/pipeline.yaml"
if not os.path.exists(config_path):
    print(f"ERROR: Config file not found at {config_path}")
    print("Available files in checkpoints/hf/checkpoints/:")
    os.system("ls -la checkpoints/hf/checkpoints/")
    sys.exit(1)
try:
    inference_model = Inference(config_path, compile=False)
    # Enable TF32 math (takes effect on Ampere and newer GPUs)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    print(f"✓ Model loaded on: {torch.cuda.get_device_name(0)}")
    print(f"✓ Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
except Exception as e:
    print(f"ERROR loading model: {e}")
    traceback.print_exc()
    sys.exit(1)
@app.route('/health', methods=['GET'])
@require_api_key
def health():
    return jsonify({
        "status": "healthy",
        "gpu": torch.cuda.get_device_name(0),
        "vram_total_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
        "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
        "model_loaded": True
    })
@app.route('/reconstruct', methods=['POST'])
@require_api_key
def reconstruct_3d():
    try:
        start_time = datetime.now()
        # Handle both JSON and form data
        if request.is_json:
            data = request.json
            image_data = base64.b64decode(data['image'])
        else:
            # Handle file upload
            if 'image' not in request.files:
                return jsonify({"success": False, "error": "No image provided"}), 400
            image_file = request.files['image']
            image_data = image_file.read()

        # Load image
        image = Image.open(io.BytesIO(image_data))
        print(f"Processing image: {image.size}, mode: {image.mode}")

        # Convert PIL Image to numpy array for inference
        image_array = np.array(image)

        # Handle mask if provided
        mask = None
        if request.is_json and 'mask' in data:
            mask_data = base64.b64decode(data['mask'])
            mask = Image.open(io.BytesIO(mask_data))
        elif 'mask' in request.files:
            mask_data = request.files['mask'].read()
            mask = Image.open(io.BytesIO(mask_data))

        if mask is None:
            # No mask provided: create a white mask covering the full image
            mask = np.ones((image_array.shape[0], image_array.shape[1]), dtype=np.float32)
        else:
            # Convert mask to numpy array if it's a PIL Image
            if isinstance(mask, Image.Image):
                mask = np.array(mask)
            # Ensure mask is 2D float32
            if mask.ndim == 3:
                mask = mask[:, :, 0]  # Take first channel if RGB
            mask = mask.astype(np.float32)

        # Get seed
        seed = 42
        if request.is_json:
            seed = data.get('seed', 42)
        elif 'seed' in request.form:
            seed = int(request.form['seed'])

        # Run inference
        print(f"Running inference with seed={seed}...")
        output = inference_model(image_array, mask, seed=seed)

        # Save output
        os.makedirs('/tmp/outputs', exist_ok=True)
        timestamp = int(datetime.now().timestamp())
        output_path = f"/tmp/outputs/output_{seed}_{timestamp}.ply"
        output["gs"].save_ply(output_path)

        # Get file size
        file_size = os.path.getsize(output_path)

        # Return as base64
        with open(output_path, 'rb') as f:
            ply_data = base64.b64encode(f.read()).decode()

        elapsed = (datetime.now() - start_time).total_seconds()
        gpu_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"✓ Completed in {elapsed:.2f}s, GPU memory: {gpu_memory_gb:.2f}GB")

        return jsonify({
            "success": True,
            "ply_file": ply_data,
            "file_size_mb": round(file_size / 1024**2, 2),
            "format": "gaussian_splat",
            "processing_time_seconds": round(elapsed, 2),
            "gpu_memory_used_gb": round(gpu_memory_gb, 2),
            "image_size": [image_array.shape[1], image_array.shape[0]],
            "seed": seed
        })
    except Exception as e:
        print(f"ERROR: {e}")
        traceback.print_exc()
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
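
# Example JSON payload accepted by /reconstruct ("mask" and "seed" are
# optional; "image" and "mask" carry base64-encoded file contents):
#   {"image": "<base64 PNG/JPEG>", "mask": "<base64 mask PNG>", "seed": 42}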
@app.route('/stats', methods=['GET'])
@require_api_key
def gpu_stats():
    return jsonify({
        "gpu_name": torch.cuda.get_device_name(0),
        "total_memory_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
        "allocated_memory_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
        "cached_memory_gb": round(torch.cuda.memory_reserved() / 1024**3, 2),
        "available_memory_gb": round((torch.cuda.get_device_properties(0).total_memory -
                                      torch.cuda.memory_allocated()) / 1024**3, 2)
    })
@app.route('/', methods=['GET'])
def index():
    return jsonify({
        "service": "SAM 3D Objects API",
        "version": "1.0",
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU",
        "endpoints": {
            "/health": "GET - Health check",
            "/stats": "GET - GPU statistics",
            "/reconstruct": "POST - 3D reconstruction (JSON with base64 image or multipart form-data)"
        }
    })
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000, debug=False)
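
For reference, here is a minimal client sketch for the /reconstruct endpoint. It assumes the service above is running on localhost:8000; sample.png, output.ply, and the API key value are placeholders.

# client.py - minimal sketch for calling /reconstruct (paths/values assumed)
import base64
import requests

with open("sample.png", "rb") as f:  # placeholder input image
    image_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8000/reconstruct",
    json={"image": image_b64, "seed": 42},
    headers={"X-API-Key": "my-secret"},  # omit if API_SECRET is not set
    timeout=600,  # inference can take a while
)
resp.raise_for_status()
result = resp.json()

# The Gaussian-splat PLY comes back base64-encoded in the JSON response
with open("output.ply", "wb") as f:
    f.write(base64.b64decode(result["ply_file"]))
print(f"Saved output.ply ({result['file_size_mb']} MB in "
      f"{result['processing_time_seconds']}s)")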