Use Custom Model Import (CMI) on Bedrock
| """ | |
| AWS Bedrock Custom Model Import - Qwen Models | |
| """ | |
| import os | |
| import json | |
| import boto3 | |
| import base64 | |
| import time | |
| from typing import Dict, List, Optional | |
| from huggingface_hub import snapshot_download | |
| from botocore.exceptions import ClientError | |
class QwenBedrockDeployer:
    def __init__(self, aws_region: str = "us-east-1"):
        """Initialize the deployer with AWS clients."""
        self.region = aws_region
        self.bedrock_client = boto3.client('bedrock', region_name=aws_region)
        self.bedrock_runtime = boto3.client('bedrock-runtime', region_name=aws_region)
        self.s3_client = boto3.client('s3', region_name=aws_region)
        self.iam_client = boto3.client('iam', region_name=aws_region)
    def create_bedrock_service_role(self, role_name: str = "BedrockCustomModelImportRole") -> str:
        """Create an IAM role with proper permissions for Bedrock model import."""
        print(f"Creating IAM role: {role_name}")
        trust_policy = {
            "Version": "2012-10-17",
            "Statement": [{
                "Effect": "Allow",
                "Principal": {"Service": "bedrock.amazonaws.com"},
                "Action": "sts:AssumeRole"
            }]
        }
        # Note: this grants read access to every bucket in the account;
        # scope Resource down to your model bucket for production use.
        s3_policy = {
            "Version": "2012-10-17",
            "Statement": [{
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:ListBucket"],
                "Resource": ["arn:aws:s3:::*", "arn:aws:s3:::*/*"]
            }]
        }
        try:
            role_response = self.iam_client.create_role(
                RoleName=role_name,
                AssumeRolePolicyDocument=json.dumps(trust_policy),
                Description="Role for Bedrock Custom Model Import"
            )
            role_arn = role_response['Role']['Arn']
            self.iam_client.put_role_policy(
                RoleName=role_name,
                PolicyName=f"{role_name}-S3Policy",
                PolicyDocument=json.dumps(s3_policy)
            )
            print(f"IAM role created: {role_arn}")
            print("Waiting 10 seconds for role propagation...")
            time.sleep(10)
            return role_arn
        except ClientError as e:
            if e.response['Error']['Code'] == 'EntityAlreadyExists':
                role_response = self.iam_client.get_role(RoleName=role_name)
                role_arn = role_response['Role']['Arn']
                print(f"Using existing IAM role: {role_arn}")
                return role_arn
            else:
                raise Exception(f"Error creating IAM role: {str(e)}")
    def verify_s3_bucket(self, bucket_name: str) -> bool:
        """Verify if S3 bucket exists and is accessible."""
        try:
            self.s3_client.head_bucket(Bucket=bucket_name)
            print(f"S3 bucket '{bucket_name}' is accessible")
            return True
        except ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                print(f"S3 bucket '{bucket_name}' does not exist")
            elif error_code == '403':
                print(f"Access denied to S3 bucket '{bucket_name}'")
            else:
                print(f"Error accessing S3 bucket: {str(e)}")
            return False
    def validate_s3_model(self, s3_uri: str) -> bool:
        """Validate if model files exist at the given S3 location."""
        try:
            # Parse S3 URI
            if not s3_uri.startswith('s3://'):
                print("Invalid S3 URI format. Must start with 's3://'")
                return False
            parts = s3_uri.replace('s3://', '').split('/', 1)
            if len(parts) != 2:
                print("Invalid S3 URI format")
                return False
            bucket_name, prefix = parts[0], parts[1].rstrip('/')
            print(f"Checking for model files in s3://{bucket_name}/{prefix}")
            # List objects
            response = self.s3_client.list_objects_v2(
                Bucket=bucket_name,
                Prefix=prefix,
                MaxKeys=100
            )
            if 'Contents' not in response:
                print("No files found at specified S3 location")
                return False
            files = [obj['Key'].split('/')[-1] for obj in response['Contents']]
            # Check for essential files
            config_found = any('config.json' in f for f in files)
            model_found = any(f.endswith(('.bin', '.safetensors')) and 'model' in f for f in files)
            if config_found and model_found:
                print(f"Found {len(files)} model files at S3 location")
                return True
            else:
                missing = []
                if not config_found:
                    missing.append('config.json')
                if not model_found:
                    missing.append('model weights')
                print(f"Missing essential files: {', '.join(missing)}")
                return False
        except Exception as e:
            print(f"Error validating S3 model: {str(e)}")
            return False
    def download_model_from_huggingface(self, model_id: str, local_dir: str) -> str:
        """Download a model from Hugging Face."""
        print(f"Downloading model {model_id} to {local_dir}...")
        # Faster downloads; requires the optional hf_transfer package
        # (pip install hf_transfer). Remove this line if it is not installed.
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        os.makedirs(local_dir, exist_ok=True)
        try:
            # resume_download is deprecated in recent huggingface_hub releases;
            # interrupted downloads resume automatically.
            snapshot_download(
                repo_id=model_id,
                local_dir=local_dir
            )
            print(f"Model downloaded successfully to {local_dir}")
            return local_dir
        except Exception as e:
            raise Exception(f"Error downloading model: {str(e)}")
    def upload_model_to_s3(self, local_dir: str, bucket_name: str, s3_prefix: str) -> str:
        """Upload model files to S3."""
        print(f"Uploading model to s3://{bucket_name}/{s3_prefix}")
        try:
            total_files = sum(len(files) for _, _, files in os.walk(local_dir))
            uploaded_files = 0
            for root, dirs, files in os.walk(local_dir):
                for file in files:
                    local_path = os.path.join(root, file)
                    relative_path = os.path.relpath(local_path, local_dir)
                    s3_path = f"{s3_prefix}/{relative_path}".replace("\\", "/")
                    uploaded_files += 1
                    print(f"Uploading {relative_path} ({uploaded_files}/{total_files})...")
                    self.s3_client.upload_file(
                        local_path,
                        bucket_name,
                        s3_path,
                        ExtraArgs={'Metadata': {'source': 'huggingface'}}
                    )
            s3_uri = f"s3://{bucket_name}/{s3_prefix}/"
            print(f"Upload complete: {s3_uri}")
            return s3_uri
        except Exception as e:
            raise Exception(f"Error uploading model: {str(e)}")
    def create_model_import_job(self, job_name: str, model_name: str,
                                s3_uri: str, role_arn: str) -> str:
        """Create a Bedrock model import job."""
        print(f"Creating model import job: {job_name}")
        try:
            response = self.bedrock_client.create_model_import_job(
                jobName=job_name,
                importedModelName=model_name,
                roleArn=role_arn,
                modelDataSource={
                    's3DataSource': {'s3Uri': s3_uri}
                },
                jobTags=[
                    {'key': 'Source', 'value': 'HuggingFace'},
                    {'key': 'ModelType', 'value': 'Qwen'}
                ]
            )
            job_arn = response['jobArn']
            print(f"Import job created: {job_arn}")
            return job_arn
        except Exception as e:
            raise Exception(f"Error creating import job: {str(e)}")
    def get_job_status(self, job_name: str) -> Dict:
        """Check the status of a model import job."""
        try:
            response = self.bedrock_client.get_model_import_job(jobIdentifier=job_name)
            return {
                'status': response['status'],
                'modelArn': response.get('importedModelArn'),
                'statusMessage': response.get('statusMessage', ''),
                'creationTime': response.get('creationTime', ''),
                'lastModifiedTime': response.get('lastModifiedTime', '')
            }
        except Exception as e:
            raise Exception(f"Error getting job status: {str(e)}")
    def wait_for_import_completion(self, job_name: str, max_wait_time: int = 3600) -> bool:
        """Wait for model import to complete, polling once per minute."""
        print(f"Waiting for import job {job_name} to complete...")
        print("This may take 30-60 minutes...")
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            status_info = self.get_job_status(job_name)
            status = status_info['status']
            elapsed_time = int(time.time() - start_time)
            print(f"[{elapsed_time//60}m {elapsed_time%60}s] Status: {status}")
            if status == 'Completed':
                print("Model import completed successfully!")
                print(f"Model ARN: {status_info['modelArn']}")
                return True
            elif status == 'Failed':
                print(f"Model import failed: {status_info['statusMessage']}")
                return False
            elif status in ['InProgress', 'Validating']:
                if status_info.get('statusMessage'):
                    print(status_info['statusMessage'])
            time.sleep(60)
        print("Timeout waiting for model import")
        return False
    def run_inference(self, model_arn: str, prompt: str, temperature: float = 0.7,
                      max_tokens: int = 2048) -> str:
        """Run text inference on the imported model."""
        try:
            print("Running inference...")
            body = {
                "prompt": prompt,
                "temperature": temperature,
                "max_gen_len": max_tokens,
                "top_p": 0.9,
                "top_k": 50
            }
            response = self.bedrock_runtime.invoke_model(
                modelId=model_arn,
                body=json.dumps(body),
                accept='application/json',
                contentType='application/json'
            )
            result = json.loads(response['body'].read().decode('utf-8'))
            # Imported models differ in response shape; check common keys.
            if 'generation' in result:
                return result['generation']
            elif 'choices' in result:
                return result['choices'][0]['text']
            elif 'outputs' in result:
                return result['outputs'][0]['text']
            elif 'text' in result:
                return result['text']
            else:
                return str(result)
        except Exception as e:
            raise Exception(f"Error running inference: {str(e)}")
    def analyze_image(self, model_arn: str, image_path: str, prompt: str = "Describe this image.",
                      temperature: float = 0.3, max_tokens: int = 4096) -> str:
        """Analyze an image using a Qwen VL model."""
        try:
            print(f"Analyzing image: {image_path}")
            with open(image_path, "rb") as image_file:
                image_base64 = base64.b64encode(image_file.read()).decode('utf-8')
            body = {
                'prompt': prompt,
                'temperature': temperature,
                'max_gen_len': max_tokens,
                'top_p': 0.9,
                'images': [image_base64]
            }
            response = self.bedrock_runtime.invoke_model(
                modelId=model_arn,
                body=json.dumps(body),
                accept='application/json',
                contentType='application/json'
            )
            result = json.loads(response['body'].read().decode('utf-8'))
            if 'generation' in result:
                return result['generation']
            elif 'choices' in result:
                return result['choices'][0]['text']
            elif 'outputs' in result:
                return result['outputs'][0]['text']
            else:
                return str(result)
        except Exception as e:
            raise Exception(f"Error analyzing image: {str(e)}")
def main():
    """Main deployment workflow."""
    try:
        print("AWS Bedrock Custom Model Import - Qwen Models")
        print("=" * 50)
        # Get configuration
        aws_region = input("Enter AWS region (default: us-east-1): ").strip() or "us-east-1"
        s3_bucket = input("Enter S3 bucket name: ").strip()
        # Initialize deployer
        deployer = QwenBedrockDeployer(aws_region)
        # Verify S3 bucket
        if not deployer.verify_s3_bucket(s3_bucket):
            print("Please create the S3 bucket or check permissions.")
            return
        # IAM role
        use_existing_role = input("Do you have an existing IAM role ARN? (y/n): ").strip().lower() == 'y'
        if use_existing_role:
            role_arn = input("Enter IAM role ARN: ").strip()
        else:
            role_arn = deployer.create_bedrock_service_role()
        # Model selection
        print("\nAvailable Qwen Models:")
        models = {
            "1": "Qwen/Qwen2.5-Coder-7B-Instruct",
            "2": "Qwen/Qwen2.5-7B-Instruct",
            "3": "Qwen/Qwen2.5-VL-7B-Instruct",
            "4": "Qwen/Qwen2.5-14B-Instruct",
            "5": "Custom"
        }
        for key, value in models.items():
            print(f"{key}. {value}")
        choice = input("Select model (1-5): ").strip()
        if choice == "5":
            model_id = input("Enter custom model ID: ").strip()
        elif choice in models:
            model_id = models[choice]
        else:
            model_id = "Qwen/Qwen2.5-7B-Instruct"
        # Check for existing S3 model
        use_s3 = input("\nDo you have the model already uploaded to S3? (y/n): ").strip().lower() == 'y'
        if use_s3:
            s3_uri = input("Enter S3 URI (e.g., s3://bucket/path/to/model): ").strip()
            if not s3_uri.endswith('/'):
                s3_uri += '/'
            if not deployer.validate_s3_model(s3_uri):
                print("Invalid S3 location or missing model files.")
                use_s3 = False
        if not use_s3:
            # Download from Hugging Face and upload to S3
            print("\nModel will be downloaded from Hugging Face and uploaded to S3.")
            model_name = model_id.replace('/', '-').replace('.', '-').lower()
            local_dir = f"./{model_name}"
            deployer.download_model_from_huggingface(model_id, local_dir)
            s3_uri = deployer.upload_model_to_s3(
                local_dir, s3_bucket, f"models/{model_name}"
            )
        # Generate job and model names
        model_name = model_id.replace('/', '-').replace('.', '-').lower()
        job_name = f"{model_name}-import-{int(time.time())}"
        custom_model_name = f"{model_name}-custom"
        print("\nDeployment Configuration:")
        print(f"  Model: {model_id}")
        print(f"  Job Name: {job_name}")
        print(f"  Custom Model Name: {custom_model_name}")
        print(f"  S3 URI: {s3_uri}")
        print(f"  Role: {role_arn}")
        proceed = input("\nProceed with import? (y/n): ").strip().lower()
        if proceed != 'y':
            print("Import cancelled.")
            return
        # Create import job
        job_arn = deployer.create_model_import_job(
            job_name, custom_model_name, s3_uri, role_arn
        )
        # Wait for completion
        if deployer.wait_for_import_completion(job_name):
            status_info = deployer.get_job_status(job_name)
            model_arn = status_info['modelArn']
            print("\nModel deployed successfully!")
            print(f"Model ARN: {model_arn}")
            # Test inference
            test = input("\nTest the model? (y/n): ").strip().lower()
            if test == 'y':
                is_vision = 'vl' in model_id.lower()
                if is_vision:
                    test_type = input("Test (1) text or (2) image? ").strip()
                    if test_type == "2":
                        image_path = input("Enter image path: ").strip()
                        if os.path.exists(image_path):
                            prompt = input("Enter prompt: ").strip() or "Describe this image."
                            result = deployer.analyze_image(model_arn, image_path, prompt)
                            print(f"\nResult:\n{result}")
                        else:
                            print("Image file not found.")
                    else:
                        prompt = input("Enter prompt: ").strip()
                        if prompt:
                            result = deployer.run_inference(model_arn, prompt)
                            print(f"\nResult:\n{result}")
                else:
                    prompt = input("Enter prompt: ").strip()
                    if prompt:
                        result = deployer.run_inference(model_arn, prompt)
                        print(f"\nResult:\n{result}")
            print("\nSave this Model ARN for future use:")
            print(f"  {model_arn}")
        else:
            print("Deployment failed. Check AWS console for details.")
    except KeyboardInterrupt:
        print("\nProcess interrupted.")
    except Exception as e:
        print(f"\nError: {str(e)}")

if __name__ == "__main__":
    main()
Ensure that you have recent versions of boto3 and huggingface_hub installed, then run this script.
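Once an import has completed, you do not need to rerun the whole interactive flow to use the model. A minimal sketch like the one below invokes the imported model directly with boto3, reusing the saved Model ARN the script prints at the end. The ARN and prompt here are placeholders, and the request body mirrors the shape this script's run_inference() sends; imported models can expect slightly different fields depending on their architecture.

import json
import boto3

# Placeholder: substitute the Model ARN printed after a successful import.
MODEL_ARN = "arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123"

runtime = boto3.client('bedrock-runtime', region_name='us-east-1')

# Same request shape the script uses for text inference.
body = {
    "prompt": "Write a haiku about distributed systems.",
    "temperature": 0.7,
    "max_gen_len": 512,
    "top_p": 0.9
}

response = runtime.invoke_model(
    modelId=MODEL_ARN,
    body=json.dumps(body),
    accept='application/json',
    contentType='application/json'
)
result = json.loads(response['body'].read())
# Response shape varies by model; fall back through common keys as the script does.
print(result.get('generation') or result.get('text') or result)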