#!/usr/bin/env python3
"""
MedGemma MRT Analysis Script - Batch Processing for DICOM Images

This script analyzes MRI DICOM images using Google's MedGemma model via Vertex AI.
It was created as an experiment to see whether AI could provide a "second opinion" on
medical imaging - spoiler: it can't replace doctors. Not even close.

IMPORTANT: This is for educational purposes only. Never rely on AI for medical diagnosis!

Prerequisites:
- Google Cloud account with Vertex AI API enabled
- gcloud CLI installed and authenticated
- A deployed MedGemma DICOM endpoint (see deployment instructions)
- DICOM files uploaded to a PUBLIC GCS bucket

Usage:
    python analyze_mrt.py <dataset>              # Test mode (1 batch)
    python analyze_mrt.py <dataset> all          # Process all batches
    python analyze_mrt.py <dataset> finalreport  # Generate summary from saved batches

Author: Created during a "let's see if AI can read MRI scans" experiment
Result: The AI hallucinated tumors and bone cement from surgeries that never happened,
        and couldn't even count the vertebrae correctly. Use at your own risk!
"""
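# -----------------------------------------------------------------------------
# Hedged example for the "public GCS bucket" prerequisite above (bucket and
# folder names are the same placeholders used in DATASETS below, not real
# resources). One way to upload the slices and make them publicly readable:
#
#   gcloud storage cp ./dicom/*.dcm gs://your-bucket-name/your-dicom-folder/
#   gcloud storage buckets add-iam-policy-binding gs://your-bucket-name \
#       --member=allUsers --role=roles/storage.objectViewer
#
# Remember to remove public access (or delete the bucket) when you're done.
# -----------------------------------------------------------------------------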
import json
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

# =============================================================================
# CONFIGURATION - UPDATE THESE VALUES FOR YOUR DEPLOYMENT
# =============================================================================

# Your Google Cloud project number (not the project ID!)
PROJECT_NUMBER = "YOUR_PROJECT_NUMBER"

# Region where your endpoint is deployed
REGION = "us-central1"

# Endpoint configuration - get these after deploying MedGemma in Vertex AI
# IMPORTANT: Use the DICOM variant (medgemma-X-it-dicom), not the regular one!
ENDPOINT_ID = "YOUR_ENDPOINT_ID"
DEDICATED_DOMAIN = "YOUR_ENDPOINT.us-central1-XXXXXX.prediction.vertexai.goog"

# Dataset configurations - customize for your data
DATASETS = {
    "dataset1": {
        "bucket": "your-bucket-name",
        "path": "your-dicom-folder",
        "description": "MRI Dataset 1",
        "examination": "MRI examination type"
    },
    # Add more datasets as needed
}
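# For illustration (the names are placeholders, not real resources): with the
# entry above, analyze_batch() will fetch each slice from
#   https://storage.googleapis.com/your-bucket-name/your-dicom-folder/<filename>
# so "bucket" and "path" must match your public GCS layout exactly.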
# Batch size - depends on your deployment configuration.
# The default is 16, but it can be raised (e.g. to 125) via a Console UI deployment.
BATCH_SIZE = 16
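# Example of how the slices get split: 100 slices with BATCH_SIZE = 16
# means 7 requests (6 batches of 16 images plus 1 batch of 4).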
# Base directory for results
BASE_RESULTS_DIR = Path(__file__).parent

# =============================================================================
# PATIENT INFORMATION TEMPLATE (symptoms only, no diagnosis!)
# =============================================================================

PATIENT_INFO = """
PATIENT DATA:
- Age: XX years
- Sex: X
- Relevant medical history: ...

CURRENT SYMPTOMS:
- Pain location: ...
- Pain characteristics: ...
- Functional limitations: ...
"""

# =============================================================================
# PROMPTS
# =============================================================================

def get_batch_prompt(batch_num: int, total_batches: int, examination: str) -> str:
    """
    Generate the prompt for an individual batch analysis.

    Key insight: keep prompts neutral! Don't hint at what you expect to find.
    The MedGemma prompting guide recommends avoiding leading questions.
    """
    return f"""You are an expert radiologist specialized in musculoskeletal MRI.

CLINICAL INFORMATION:
- Patient demographics and history here
- Current symptoms

EXAMINATION: {examination}
IMAGE SET: {batch_num} of {total_batches}

Examine these MRI slices carefully. Describe all anatomical structures visible and report any abnormal or noteworthy findings. Provide a thorough, formal radiological report.

Format your response as:
FINDINGS: (detailed description of all observations)
IMPRESSION: (summary of significant findings)"""
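# A quick illustration of the "avoid leading questions" advice above (my own
# example, not taken from the MedGemma prompting guide):
#   Leading: "Is there a herniated disc at L4/L5?"          <- nudges the model toward a finding
#   Neutral: "Report any abnormal or noteworthy findings."  <- what the prompt above uses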


def get_summary_prompt(batch_results_text: str) -> str:
    """Generate the prompt for the final summary report that combines all batches."""
    return f"""You are an experienced radiologist.

{PATIENT_INFO}

EXAMINATION: MRI scan

I have shown you the MRI images in multiple batches. Here are the findings from all slices:

{batch_results_text}

TASK: Create a COMPLETE RADIOLOGICAL REPORT summarizing all findings.

Structure the report as follows:
1. FINDINGS (detailed description of all pathologies)
2. IMPRESSION (summary of main findings)
3. RECOMMENDATIONS (clinical recommendations)

Make sure not to miss any findings, and consolidate duplicates."""


# =============================================================================
# FUNCTIONS
# =============================================================================

def get_auth_token() -> str:
    """
    Get a Google Cloud access token via the gcloud CLI.

    Make sure you're authenticated first: gcloud auth login
    """
    result = subprocess.run(
        ["gcloud", "auth", "print-access-token"],
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        sys.exit(f"Could not get an access token: {result.stderr.strip()}")
    return result.stdout.strip()
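# Hedged alternative: if you'd rather not shell out to gcloud, the google-auth
# library (pip install google-auth) can mint the same token. Sketch only - this
# script does not use it:
#
#   import google.auth
#   from google.auth.transport.requests import Request
#
#   creds, _ = google.auth.default(
#       scopes=["https://www.googleapis.com/auth/cloud-platform"]
#   )
#   creds.refresh(Request())
#   token = creds.token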


def get_gcs_files(bucket: str, path: str) -> list:
    """
    List all DICOM files in the given GCS bucket folder.

    IMPORTANT: The bucket must be publicly readable!
    The MedGemma endpoint fetches the images via HTTP URLs.
    """
    result = subprocess.run(
        ["gcloud", "storage", "ls", f"gs://{bucket}/{path}/"],
        capture_output=True,
        text=True
    )

    # Keep only object names; skip blank lines and sub-folder prefixes
    # (which end in "/" and would otherwise yield empty filenames).
    all_files = sorted([
        line.strip().split("/")[-1]
        for line in result.stdout.strip().split("\n")
        if line.strip() and not line.strip().endswith("/")
    ])

    return all_files
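# Hedged alternative sketch using the google-cloud-storage client instead of
# the gcloud CLI (pip install google-cloud-storage); not used by this script:
#
#   from google.cloud import storage
#
#   blobs = storage.Client().list_blobs("your-bucket-name", prefix="your-dicom-folder/")
#   files = sorted(b.name.split("/")[-1] for b in blobs if not b.name.endswith("/"))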


def analyze_batch(bucket: str, path: str, filenames: list, prompt: str, token: str) -> str:
    """
    Send a batch of DICOM images to MedGemma for analysis.

    Key learnings:
    - Use HTTP URLs, not gs:// URIs
    - The text prompt must come FIRST, then the images
    - One corrupted file = the entire batch fails
    - Uses curl because Python requests had SSL issues
    """
    api_url = f"https://{DEDICATED_DOMAIN}/v1/projects/{PROJECT_NUMBER}/locations/{REGION}/endpoints/{ENDPOINT_ID}:predict"

    # Build the content array: text first, then images
    content_parts = [{"type": "text", "text": prompt}]

    for filename in filenames:
        http_url = f"https://storage.googleapis.com/{bucket}/{path}/{filename}"
        content_parts.append({
            "type": "image_url",
            "image_url": {"url": http_url}
        })

    # OpenAI-style chat completion format (required for the DICOM models)
    payload = {
        "instances": [{
            "messages": [{"role": "user", "content": content_parts}],
            "max_tokens": 2048
        }]
    }
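    # For illustration (shape taken from the code above, values are dummies),
    # a two-image batch serializes to roughly:
    #   {"instances": [{"messages": [{"role": "user", "content": [
    #       {"type": "text", "text": "<prompt>"},
    #       {"type": "image_url", "image_url": {"url": ".../slice_001.dcm"}},
    #       {"type": "image_url", "image_url": {"url": ".../slice_002.dcm"}}
    #   ]}], "max_tokens": 2048}]}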
    # Write payload to a temp file (curl reads it from there)
    payload_file = Path(__file__).parent / "temp_payload.json"
    with open(payload_file, "w") as f:
        json.dump(payload, f)

    # Use curl - more reliable than Python requests for this endpoint
    result = subprocess.run(
        [
            "curl", "-s", "-X", "POST", api_url,
            "-H", f"Authorization: Bearer {token}",
            "-H", "Content-Type: application/json",
            "-d", f"@{payload_file}",
            "--max-time", "300"
        ],
        capture_output=True,
        text=True
    )
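    # Hedged sketch of the same call with the requests library (pip install
    # requests). The author reports SSL issues with it against this endpoint,
    # so curl stays the default here:
    #
    #   import requests
    #   resp = requests.post(
    #       api_url,
    #       headers={"Authorization": f"Bearer {token}"},
    #       json=payload,
    #       timeout=300,
    #   )
    #   raw = resp.text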
    payload_file.unlink()  # Clean up the temp file

    # Parse the response; the endpoint is expected to return an OpenAI-style
    # chat completion under "predictions"
    try:
        response = json.loads(result.stdout)
        if "error" in response:
            return f"ERROR: {response['error']}"
        if "predictions" in response and "choices" in response["predictions"]:
            return response["predictions"]["choices"][0]["message"]["content"]
        return f"Unexpected response: {result.stdout[:500]}"
    except json.JSONDecodeError:
        return f"JSON parse error: {result.stdout[:500]}"


def save_batch_result(batch_num: int, result: str, filenames: list, results_dir: Path) -> Path:
    """Save individual batch result to JSON file."""
    results_dir.mkdir(exist_ok=True)

    batch_file = results_dir / f"batch_{batch_num:02d}.json"
    data = {
        "batch_num": batch_num,
        "timestamp": datetime.now().isoformat(),
        "files": filenames,
        "result": result
    }

    with open(batch_file, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    return batch_file


def load_all_batch_results(results_dir: Path) -> list:
    """Load all saved batch results from directory."""
    if not results_dir.exists():
        return []

    results = []
    for batch_file in sorted(results_dir.glob("batch_*.json")):
        with open(batch_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        results.append(data)

    return results


def create_final_report(batch_results: list, token: str) -> str:
    """
    Create final summary report from all batch results.

    This sends all the individual batch findings back to MedGemma
    to generate a consolidated report.
    """
    api_url = f"https://{DEDICATED_DOMAIN}/v1/projects/{PROJECT_NUMBER}/locations/{REGION}/endpoints/{ENDPOINT_ID}:predict"

    # Format all batch results
    results_text = "\n\n".join([
        f"=== BATCH {r['batch_num']} ===\n{r['result']}"
        for r in batch_results
    ])

    prompt = get_summary_prompt(results_text)

    payload = {
        "instances": [{
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 4096
        }]
    }

    payload_file = Path(__file__).parent / "temp_payload.json"
    with open(payload_file, "w") as f:
        json.dump(payload, f)

    print("Sending request for final report...")

    result = subprocess.run(
        [
            "curl", "-s", "-X", "POST", api_url,
            "-H", f"Authorization: Bearer {token}",
            "-H", "Content-Type: application/json",
            "-d", f"@{payload_file}",
            "--max-time", "300"
        ],
        capture_output=True,
        text=True
    )

    payload_file.unlink()

    try:
        response = json.loads(result.stdout)
        if "error" in response:
            return f"ERROR: {response['error']}"
        if "predictions" in response and "choices" in response["predictions"]:
            return response["predictions"]["choices"][0]["message"]["content"]
        return f"Unexpected response: {result.stdout[:500]}"
    except json.JSONDecodeError:
        return f"JSON parse error: {result.stdout[:500]}"


def save_final_report(summary: str, batch_results: list) -> Path:
    """Save the final consolidated report."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"mrt_FINAL_{timestamp}.txt"
    output_path = Path(__file__).parent / filename

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("MEDGEMMA MRI ANALYSIS - FINAL REPORT\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Batches analyzed: {len(batch_results)}\n")
        f.write(f"Total slices: {sum(len(r['files']) for r in batch_results)}\n")
        f.write("=" * 80 + "\n\n")

        f.write("PATIENT SYMPTOMS:\n")
        f.write("-" * 40 + "\n")
        f.write(PATIENT_INFO + "\n\n")

        f.write("=" * 80 + "\n")
        f.write("FINAL REPORT (AI-GENERATED):\n")
        f.write("-" * 40 + "\n")
        f.write(summary + "\n\n")

        f.write("=" * 80 + "\n")
        f.write("DETAIL: INDIVIDUAL BATCH ANALYSES\n")
        f.write("=" * 80 + "\n")
        for r in batch_results:
            f.write(f"\n--- Batch {r['batch_num']} ({len(r['files'])} images) ---\n")
            f.write(r['result'] + "\n")

        f.write("\n" + "=" * 80 + "\n")
        f.write("DISCLAIMER: AI-generated report - must be validated by a physician!\n")
        f.write("=" * 80 + "\n")

    return output_path


def save_batch_only_report(batch_results: list) -> Path:
    """Save batch results without final summary."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"mrt_batches_{timestamp}.txt"
    output_path = Path(__file__).parent / filename

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("MEDGEMMA MRI ANALYSIS - BATCH RESULTS\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Batches analyzed: {len(batch_results)}\n")
        f.write("=" * 80 + "\n")
        f.write("NOTE: Run 'python analyze_mrt.py <dataset> finalreport' for summary\n")
        f.write("=" * 80 + "\n\n")

        for r in batch_results:
            f.write(f"\n{'='*60}\n")
            f.write(f"BATCH {r['batch_num']} ({len(r['files'])} images)\n")
            f.write(f"{'='*60}\n")
            f.write(r['result'] + "\n")

        f.write("\n" + "=" * 80 + "\n")
        f.write("DISCLAIMER: AI-generated report - must be validated by a physician!\n")
        f.write("=" * 80 + "\n")

    return output_path


# =============================================================================
# MAIN FUNCTIONS
# =============================================================================

def run_batch_analysis(dataset_name: str, run_all: bool):
    """Execute the batch analysis."""

    if dataset_name not in DATASETS:
        print(f"Unknown dataset: {dataset_name}")
        print(f"Available: {', '.join(DATASETS.keys())}")
        return

    dataset = DATASETS[dataset_name]
    results_dir = BASE_RESULTS_DIR / f"batch_results_{dataset_name}"

    print("=" * 60)
    print(f"MEDGEMMA MRI ANALYSIS: {dataset['description']}")
    if run_all:
        print("Mode: ALL BATCHES")
    else:
        print("Mode: TEST (1 batch only)")
    print("=" * 60)

    # Get auth token
    print("\nGetting auth token...")
    token = get_auth_token()
    print("Token obtained")

    # Get all DICOM files from GCS
    print(f"\nFetching DICOM files from {dataset['bucket']}/{dataset['path']}...")
    all_files = get_gcs_files(dataset['bucket'], dataset['path'])

    if not all_files:
        print("No DICOM files found!")
        return

    print(f"Found: {len(all_files)} files")

    # Split into batches
    all_batches = [all_files[i:i+BATCH_SIZE] for i in range(0, len(all_files), BATCH_SIZE)]

    if run_all:
        batches = all_batches
        print(f"{len(batches)} batches of max {BATCH_SIZE} images each")
    else:
        batches = all_batches[:1]
        print(f"TEST: Only batch 1 of {len(all_batches)}")

    # Clear previous results for a new run
    if results_dir.exists():
        for f in results_dir.glob("batch_*.json"):
            f.unlink()

    # Analyze each batch
    print("\n" + "=" * 60)
    print("STARTING BATCH ANALYSIS...")
    print("=" * 60)

    batch_results = []
    for i, batch in enumerate(batches):
        batch_num = i + 1
        print(f"\nBatch {batch_num}/{len(batches)} ({len(batch)} images)...")

        prompt = get_batch_prompt(batch_num, len(all_batches), dataset['examination'])
        result = analyze_batch(dataset['bucket'], dataset['path'], batch, prompt, token)

        # Save result
        save_batch_result(batch_num, result, batch, results_dir)
        batch_results.append({
            "batch_num": batch_num,
            "files": batch,
            "result": result
        })

        # Preview
        preview = result[:150].replace('\n', ' ')
        print(f"  Result: {preview}...")

        # Brief pause between requests
        if i < len(batches) - 1:
            time.sleep(2)

    # Save batch report
    output_path = save_batch_only_report(batch_results)

    print("\n" + "=" * 60)
    print("BATCH ANALYSIS COMPLETE")
    print(f"  {len(batch_results)} batches analyzed")
    print(f"  Results in: {results_dir}/")
    print(f"  Report: {output_path}")
    print("=" * 60)

    if run_all:
        print(f"\nFor final report: python analyze_mrt.py {dataset_name} finalreport")
    else:
        print(f"\nFor all batches: python analyze_mrt.py {dataset_name} all")

    print("\nIMPORTANT: Delete your endpoint after use to save costs!")
    print("You're being charged per second!")


def run_final_report(dataset_name: str):
    """Generate final summary from saved batch results."""

    if dataset_name not in DATASETS:
        print(f"Unknown dataset: {dataset_name}")
        print(f"Available: {', '.join(DATASETS.keys())}")
        return

    dataset = DATASETS[dataset_name]
    results_dir = BASE_RESULTS_DIR / f"batch_results_{dataset_name}"

    print("=" * 60)
    print(f"GENERATING FINAL REPORT: {dataset['description']}")
    print("=" * 60)

    # Load batch results
    batch_results = load_all_batch_results(results_dir)

    if not batch_results:
        print("No batch results found!")
        print(f"Run 'python analyze_mrt.py {dataset_name} all' first")
        return

    print(f"Found: {len(batch_results)} batch results")

    # Get auth token
    print("\nGetting auth token...")
    token = get_auth_token()
    print("Token obtained")

    # Create final report
    print("\n" + "=" * 60)
    print("CREATING FINAL REPORT...")
    print("=" * 60)

    summary = create_final_report(batch_results, token)

    # Save
    output_path = save_final_report(summary, batch_results)

    # Show preview
    print("\n" + "=" * 60)
    print("FINAL REPORT (Preview):")
    print("=" * 60)
    preview = summary[:800] + "..." if len(summary) > 800 else summary
    print(preview)
    print(f"\nFull report saved to: {output_path}")

    print("\n" + "=" * 60)
    print("FINAL REPORT CREATED")
    print("=" * 60)
    print("\nIMPORTANT: Delete your endpoint and buckets after use!")


# =============================================================================
# MAIN
# =============================================================================

def print_usage():
    print("Usage:")
    print("  python analyze_mrt.py <dataset>              # Test (1 batch)")
    print("  python analyze_mrt.py <dataset> all          # All batches")
    print("  python analyze_mrt.py <dataset> finalreport  # Generate final report")
    print("")
    print("Datasets:")
    for name, config in DATASETS.items():
        print(f"  {name:8} - {config['description']}")


def main():
    if len(sys.argv) < 2:
        print("Dataset not specified!")
        print_usage()
        return

    dataset_name = sys.argv[1].lower()

    if dataset_name in ["help", "-h", "--help"]:
        print_usage()
        return

    if dataset_name not in DATASETS:
        print(f"Unknown dataset: {dataset_name}")
        print_usage()
        return

    # Second parameter (optional)
    if len(sys.argv) > 2:
        arg = sys.argv[2].lower()
        if arg == "all":
            run_batch_analysis(dataset_name, run_all=True)
        elif arg == "finalreport":
            run_final_report(dataset_name)
        else:
            print(f"Unknown mode: {arg}")
            print_usage()
    else:
        # Default: test mode (1 batch)
        run_batch_analysis(dataset_name, run_all=False)


if __name__ == "__main__":
    main()