Last active
August 8, 2025 07:53
-
-
Save Quentin-M/b2c9eee1a9a1ac3074dfd96cd75091fc to your computer and use it in GitHub Desktop.
gen_ai_auth.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # | |
| # This script is meant to be used as a `credential_process` in an `~/.aws/config` file. | |
| # It triggers an AWS SSO login flow when necessary, then assumes the role set in `GEN_AI_ROLE` below. | |
| # This is useful, for example, for generative AI applications that expect to be given a single | |
| # valid AWS profile but cannot themselves perform the SSO flow required to make that profile | |
| # valid. | |
| # | |
| # - Create a virtual env: `python3 -m venv .venv && source ./.venv/bin/activate` | |
| # - Install the dependencies: `pip3 install pyinstaller boto3 portalocker` | |
| # - Package up with dependencies using: `pyinstaller --onedir gen_ai_auth.py` | |
| # - Copy the executable: `cp -r dist/gen_ai_auth /usr/local/bin/` | |
| # - Configure your `~/.aws/config`: | |
| # [profile gen-ai-bedrock] | |
| # credential_process = /usr/local/bin/gen_ai_auth/gen_ai_auth | |
| # region = us-west-2 | |
| import boto3 | |
| import json | |
| import os | |
| import time | |
| import webbrowser | |
| import hashlib | |
| import portalocker | |
| from datetime import datetime, timedelta, timezone | |
# --- IAM Identity Center (SSO) login settings --------------------------------
SSO_ACCOUNT_ID = "352991325246"  # accountId passed to sso.get_role_credentials
SSO_ROLE_NAME = "GenerativeAI"  # roleName passed to sso.get_role_credentials
SSO_START_URL = "https://d-9367747b8a.awsapps.com/start"
SSO_REGION = "eu-west-1"

# --- Role assumed (via STS) with the SSO-issued credentials ------------------
GEN_AI_ROLE = "arn:aws:iam::194722420547:role/bedrock-user"
GEN_AI_REGION = "us-west-2"

# --- On-disk token cache and login lock --------------------------------------
SSO_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".aws", "sso", "cache")
# The cache file name is the SHA-1 hex digest of start URL + region, so each
# portal/region pair gets its own file inside SSO_CACHE_DIR.
SSO_CACHE_PATH = os.path.join(
    SSO_CACHE_DIR,
    hashlib.sha1((SSO_START_URL + SSO_REGION).encode("utf-8")).hexdigest() + ".json",
)
SSO_LOCK_PATH = os.path.join(SSO_CACHE_DIR, "gen_ai.lock")
SSO_LOCK_TTL = timedelta(seconds=60)  # lock-file age limit (one minute)
def take_flock(timeout=60):
    """Return an (unacquired) file lock guarding the SSO login flow.

    Creates ``SSO_CACHE_DIR`` if needed, then builds a ``portalocker.Lock`` on
    ``SSO_LOCK_PATH``. Use it as a context manager: ``with take_flock(): ...``.

    Args:
        timeout: Seconds to keep retrying acquisition before portalocker gives
            up with a ``LockException`` subclass (caught by ``sso_do``).
            Defaults to 60, the previously hard-coded value.

    Returns:
        The ``portalocker.Lock`` instance.

    Note:
        The lock file is intentionally never deleted. portalocker relies on
        OS-level file locks, which the OS releases when the holding process
        exits, so a left-over lock file is harmless. Deleting the file while
        another process is still inside the (potentially minutes-long) device
        login would let a newcomer lock a fresh file on a different inode and
        start a second, parallel login. On Windows, deleting the locked file
        would fail outright.
    """
    os.makedirs(SSO_CACHE_DIR, exist_ok=True)
    return portalocker.Lock(SSO_LOCK_PATH, timeout=timeout)
def get_sso_cached_token(cache_path=None):
    """Return the cached SSO access token if it has not expired yet.

    Args:
        cache_path: Cache file to read. Defaults to ``SSO_CACHE_PATH``.

    Returns:
        The ``accessToken`` value from the cache file when its ``expiresAt``
        timestamp (format ``%Y-%m-%dT%H:%M:%SZ``, interpreted as UTC) lies in
        the future. Returns ``None`` when the file does not exist, has no
        ``expiresAt``, or the token has expired.

    Raises:
        Exception: if the file exists but is not valid UTF-8 JSON, is not a
            JSON object, or carries a malformed ``expiresAt`` value.
    """
    path = SSO_CACHE_PATH if cache_path is None else cache_path
    # EAFP: opening directly avoids the exists()/open() race with a concurrent writer.
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise Exception(f"Error reading cache file at {path}: {e}") from e

    try:
        if not isinstance(data, dict):
            raise ValueError("expected a JSON object")
        expires_at_str = data.get("expiresAt")
        if not expires_at_str:
            return None
        expires_at = datetime.strptime(expires_at_str, "%Y-%m-%dT%H:%M:%SZ").replace(
            tzinfo=timezone.utc
        )
    except (TypeError, ValueError) as e:  # TypeError: expiresAt is not a string
        raise Exception(f"Error reading cache file at {path}: {e}") from e

    if expires_at > datetime.now(timezone.utc):
        return data.get("accessToken")
    return None
def store_sso_cached_token(sso_client, sso_access_token, cache_path=None):
    """Persist a freshly issued SSO access token and its client registration.

    Args:
        sso_client: ``register_client`` response; ``clientId``,
            ``clientSecret`` and ``clientSecretExpiresAt`` (epoch seconds) are read.
        sso_access_token: ``create_token`` response; ``accessToken`` and
            ``expiresIn`` (seconds from now) are read.
        cache_path: Destination file. Defaults to ``SSO_CACHE_PATH``.

    The file is JSON with keys ``startUrl``, ``region``, ``accessToken``,
    ``expiresAt``, ``clientId``, ``clientSecret`` and
    ``registrationExpiresAt``. Timestamps are UTC in ``%Y-%m-%dT%H:%M:%SZ``
    format, the format ``get_sso_cached_token`` parses.
    """
    path = SSO_CACHE_PATH if cache_path is None else cache_path
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=sso_access_token["expiresIn"])
    secret_expires_at = datetime.fromtimestamp(sso_client["clientSecretExpiresAt"], tz=timezone.utc)
    cache_data = {
        "startUrl": SSO_START_URL,
        "region": SSO_REGION,
        "accessToken": sso_access_token["accessToken"],
        "expiresAt": expires_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "clientId": sso_client["clientId"],
        "clientSecret": sso_client["clientSecret"],
        "registrationExpiresAt": secret_expires_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
    }
    # sso_do reads this file without holding the lock. A plain truncate-then-write
    # could expose a half-written file to such a reader, so write a sibling temp
    # file and atomically swap it in with os.replace().
    # The file holds a bearer token and a client secret: create it owner-only (0600).
    tmp_path = f"{path}.{os.getpid()}.tmp"
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(cache_data, f, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a stray temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def sso_register():
    """Register a public OIDC client with the ``sso-oidc`` service in ``SSO_REGION``.

    Returns:
        The raw ``register_client`` response. Callers read ``clientId``,
        ``clientSecret`` and ``clientSecretExpiresAt`` from it.

    Raises:
        Exception: wrapping any failure (client creation or the API call),
            with the message prefixed by ``"Exception in sso_register: "``.
    """
    try:
        return boto3.client("sso-oidc", region_name=SSO_REGION).register_client(
            clientName="gen-ai-sso",
            clientType="public",
        )
    except Exception as e:
        raise Exception(f"Exception in sso_register: {e}")
def sso_start(client_registration):
    """Start the device-authorization flow and send the user to approve it.

    Args:
        client_registration: ``register_client`` response; ``clientId`` and
            ``clientSecret`` are read.

    Returns:
        The raw ``start_device_authorization`` response. ``sso_token`` later
        reads ``deviceCode``, ``interval`` and ``expiresIn`` from it.

    Raises:
        Exception: wrapping any failure, prefixed with ``"Exception in sso_start: "``.
    """
    import sys  # only needed for the stderr fallback below

    try:
        oidc_client = boto3.client("sso-oidc", region_name=SSO_REGION)
        sso_device_auth_resp = oidc_client.start_device_authorization(
            clientId=client_registration["clientId"],
            clientSecret=client_registration["clientSecret"],
            startUrl=SSO_START_URL,
        )
        verification_url = sso_device_auth_resp["verificationUriComplete"]
        # Fallback for headless/SSH sessions where no browser can be opened.
        # It goes to stderr because stdout must carry only the credential_process JSON.
        print(
            f"To authorize this AWS SSO login, open: {verification_url}",
            file=sys.stderr,
        )
        webbrowser.open(verification_url, new=2)  # new=2: new tab if possible
        return sso_device_auth_resp
    except Exception as e:
        raise Exception(f"Exception in sso_start: {e}") from e
def sso_token(client_registration, device_authorization_response):
    """Poll ``CreateToken`` until the user approves the device authorization.

    Args:
        client_registration: ``register_client`` response (``clientId``,
            ``clientSecret``).
        device_authorization_response: ``start_device_authorization`` response
            (``deviceCode``, ``expiresIn``, and optionally ``interval``).

    Returns:
        The raw ``create_token`` response. Callers read ``accessToken`` and
        ``expiresIn`` from it.

    Raises:
        Exception: if the device code expires, polling outlasts ``expiresIn``,
            or any other error occurs while requesting the token.
    """
    oidc_client = boto3.client("sso-oidc", region_name=SSO_REGION)
    device_code = device_authorization_response["deviceCode"]
    # RFC 8628 §3.2: when no interval is returned, clients must poll every 5 seconds.
    interval = device_authorization_response.get("interval", 5)
    expires_in = device_authorization_response["expiresIn"]
    deadline = time.time() + expires_in
    while time.time() < deadline:
        try:
            return oidc_client.create_token(
                clientId=client_registration["clientId"],
                clientSecret=client_registration["clientSecret"],
                deviceCode=device_code,
                grantType="urn:ietf:params:oauth:grant-type:device_code",
            )
        except oidc_client.exceptions.AuthorizationPendingException:
            pass  # user has not approved yet; poll again after `interval`
        except oidc_client.exceptions.SlowDownException:
            # RFC 8628 §3.5: add 5 seconds to the interval for this and all
            # subsequent polls, rather than backing off only once.
            interval += 5
        except oidc_client.exceptions.ExpiredTokenException:
            raise Exception("The device code has expired. Please restart the process.")
        except Exception as e:
            raise Exception(f"An unexpected error occurred during token retrieval: {e}") from e
        time.sleep(interval)
    raise Exception("Authorization timed out. Please restart the process.")
def sso_credentials(access_token):
    """Exchange an SSO access token for short-lived AWS credentials.

    Calls ``sso.get_role_credentials`` for ``SSO_ROLE_NAME`` in
    ``SSO_ACCOUNT_ID``.

    Args:
        access_token: The access token string returned by ``sso_do``.

    Returns:
        The ``roleCredentials`` mapping. ``assume_role`` reads
        ``accessKeyId``, ``secretAccessKey`` and ``sessionToken`` from it.

    Raises:
        Exception: wrapping any failure. The message now interpolates the
            underlying error; previously the missing ``f`` prefix printed a
            literal ``{e}``.
    """
    try:
        sso_client = boto3.client("sso", region_name=SSO_REGION)
        credentials = sso_client.get_role_credentials(
            roleName=SSO_ROLE_NAME,
            accountId=SSO_ACCOUNT_ID,
            accessToken=access_token,
        )
        return credentials["roleCredentials"]
    except Exception as e:
        raise Exception(f"Exception in sso_credentials: {e}") from e
def sso_do():
    """Return a valid SSO access token, running the device login only when needed.

    Side effect: every ``AWS_*`` variable is removed from ``os.environ``
    first, so boto3 does not pick up a profile that points back at this
    script (which would make it call itself).

    Flow:
      1. Return the cached token if it is still valid (no lock needed).
      2. Otherwise take the login lock, re-check the cache (another process
         may have finished logging in meanwhile), then register, start the
         device flow, poll for the token, cache it, and return it.

    Raises:
        Exception: if the login lock cannot be acquired, or propagated from
            any step of the flow.
    """
    aws_vars = [name for name in os.environ if name.startswith('AWS_')]
    for name in aws_vars:
        os.environ.pop(name)

    token = get_sso_cached_token()
    if token:
        return token

    # Apps often fire several AWS calls at once; the lock ensures only one
    # of the resulting processes runs the interactive login.
    try:
        with take_flock():
            token = get_sso_cached_token()
            if not token:
                registration = sso_register()
                device_auth = sso_start(registration)
                token_response = sso_token(registration, device_auth)
                store_sso_cached_token(registration, token_response)
                token = token_response["accessToken"]
            return token
    except portalocker.LockException:
        raise Exception(f"Exception in sso_do: cannot acquire lock on {SSO_LOCK_PATH}")
def assume_role(sso_creds, role_session_name="AssumeRoleSession"):
    """Assume ``GEN_AI_ROLE`` using SSO credentials.

    Returns the result as a credential_process payload.

    Args:
        sso_creds: ``roleCredentials`` mapping from ``sso_credentials``
            (``accessKeyId``, ``secretAccessKey``, ``sessionToken``).
        role_session_name: STS ``RoleSessionName``; defaults to the previously
            hard-coded ``"AssumeRoleSession"``.

    Returns:
        A dict with ``Version`` (1), ``AccessKeyId``, ``SecretAccessKey``,
        ``SessionToken`` and ``Expiration`` (the STS expiration datetime
        rendered with ``isoformat()``).

    Raises:
        Exception: wrapping any failure, prefixed with ``"Exception in assume_role: "``
            (the message previously named a nonexistent ``assume_credentials``).
    """
    try:
        session = boto3.Session(
            aws_access_key_id=sso_creds["accessKeyId"],
            aws_secret_access_key=sso_creds["secretAccessKey"],
            aws_session_token=sso_creds["sessionToken"],
            region_name=GEN_AI_REGION,
        )
        sts_client = session.client("sts")
        assumed_role_object = sts_client.assume_role(
            RoleArn=GEN_AI_ROLE,
            RoleSessionName=role_session_name,
        )
        credentials = assumed_role_object["Credentials"]
        return {
            "Version": 1,
            "AccessKeyId": credentials["AccessKeyId"],
            "SecretAccessKey": credentials["SecretAccessKey"],
            "SessionToken": credentials["SessionToken"],
            "Expiration": credentials["Expiration"].isoformat(),
        }
    except Exception as e:
        raise Exception(f"Exception in assume_role: {e}") from e
if __name__ == "__main__":
    # credential_process contract: emit exactly one JSON credentials document on stdout.
    # Pipeline: SSO access token -> SSO role credentials -> assumed GEN_AI_ROLE credentials.
    print(json.dumps(assume_role(sso_credentials(sso_do()))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment