Skip to content

Instantly share code, notes, and snippets.

@Quentin-M
Last active August 8, 2025 07:53
Show Gist options
  • Select an option

  • Save Quentin-M/b2c9eee1a9a1ac3074dfd96cd75091fc to your computer and use it in GitHub Desktop.

Select an option

Save Quentin-M/b2c9eee1a9a1ac3074dfd96cd75091fc to your computer and use it in GitHub Desktop.
gen_ai_auth.py
#!/usr/bin/env python3
#
# This script is meant to be used as a `credential_process` in an `~/.aws/config` file,
# and triggers an AWS SSO Login flow when necessary then assumes a pre-configured `role_arn`.
# This is useful, for example, to make certain generative AI applications work transparently when they
# expect to be given a single valid AWS profile but cannot themselves perform the SSO flow required to
# make that profile valid.
#
# - Create a virtual env: `python3 -m venv .venv && source ./.venv/bin/activate`
# - Install the dependencies: `pip3 install pyinstaller boto3 portalocker`
# - Package up with dependencies using: `pyinstaller --onedir gen_ai_auth.py`
# - Copy the executable: `cp -r dist/gen_ai_auth /usr/local/bin/`
# - Configure your `~/.aws/config`:
# [profile gen-ai-bedrock]
# credential_process = /usr/local/bin/gen_ai_auth/gen_ai_auth
# region = us-west-2
import hashlib
import json
import os
import sys
import tempfile
import time
import webbrowser
from datetime import datetime, timedelta, timezone

import boto3
import portalocker
# --- IAM Identity Center (SSO) sign-in target ---
SSO_ACCOUNT_ID = "352991325246"   # account whose SSO role is requested
SSO_ROLE_NAME = "GenerativeAI"    # SSO permission-set role name in that account
SSO_START_URL = "https://d-9367747b8a.awsapps.com/start"
SSO_REGION = "eu-west-1"          # region of the sso / sso-oidc endpoints

# --- Role that is finally handed to the calling application ---
GEN_AI_ROLE = "arn:aws:iam::194722420547:role/bedrock-user"
GEN_AI_REGION = "us-west-2"       # region of the STS session that assumes GEN_AI_ROLE

# --- Local token cache and lock file ---
SSO_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".aws", "sso", "cache")
# Cache file name: SHA-1 hex digest of the start URL concatenated with the SSO region.
_SSO_CACHE_DIGEST = hashlib.sha1((SSO_START_URL + SSO_REGION).encode("utf-8")).hexdigest()
SSO_CACHE_PATH = os.path.join(SSO_CACHE_DIR, f"{_SSO_CACHE_DIGEST}.json")
SSO_LOCK_PATH = os.path.join(SSO_CACHE_DIR, "gen_ai.lock")
# Lock-file age threshold (one minute).
SSO_LOCK_TTL = timedelta(seconds=60)
# Take a file-lock
def take_flock(timeout=60):
    """Return an exclusive file lock guarding the interactive SSO flow.

    The returned ``portalocker.Lock`` is not yet acquired; use it as a
    context manager (``with take_flock(): ...``).

    Args:
        timeout: Seconds to wait for the lock when it is acquired
            (default 60, the previously hard-coded value).

    Returns:
        A ``portalocker.Lock`` on ``SSO_LOCK_PATH``.
    """
    # exist_ok already tolerates an existing directory, so no separate
    # existence check (and no race between check and create) is needed.
    os.makedirs(SSO_CACHE_DIR, exist_ok=True)
    # The lock file is intentionally never deleted. portalocker takes an
    # OS-level file lock, which is released when the holding process exits,
    # so a leftover file cannot be "stale". Removing the file while another
    # process holds the lock would let the next caller lock a brand-new file
    # and start a second, concurrent SSO flow: exactly what the lock is meant
    # to prevent.
    return portalocker.Lock(SSO_LOCK_PATH, timeout=timeout)
# Get a cached AWS Access Token
def get_sso_cached_token(cache_path=None):
    """Return the cached SSO access token if it exists and has not expired.

    Args:
        cache_path: Cache file to read. Defaults to ``SSO_CACHE_PATH``.

    Returns:
        The cached ``accessToken`` value, or None when the cache file does
        not exist, has no ``expiresAt``, or its ``expiresAt`` is not in the
        future.

    Raises:
        Exception: if the cache file is not valid JSON, is not a JSON object,
            or its ``expiresAt`` is not formatted as ``%Y-%m-%dT%H:%M:%SZ``.
            The underlying error is chained as ``__cause__``.
    """
    if cache_path is None:
        cache_path = SSO_CACHE_PATH
    # EAFP: the file may disappear between an existence check and open().
    try:
        cache_file = open(cache_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return None
    with cache_file:
        try:
            data = json.load(cache_file)
            expires_at_str = data.get("expiresAt")
            if not expires_at_str:
                return None
            expires_at = datetime.strptime(expires_at_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError;
        # AttributeError/TypeError cover a non-object payload or a non-string
        # expiresAt, which previously escaped this handler.
        except (ValueError, TypeError, AttributeError) as e:
            raise Exception(f"Error reading cache file at {cache_path}: {e}") from e
    if expires_at <= datetime.now(timezone.utc):
        return None
    return data.get("accessToken")
# Cache the given AWS Access Token
def store_sso_cached_token(sso_client, sso_access_token, cache_path=None):
    """Persist an SSO access token together with its OIDC client registration.

    The JSON is written to a temporary file in the destination directory and
    then moved into place with ``os.replace``. Other invocations read the cache
    without holding the lock, so they must never see a half-written file.
    ``tempfile.mkstemp`` creates the file readable/writable by the owner only,
    which matters because it contains the access token and the client secret.

    Args:
        sso_client: RegisterClient response; ``clientId``, ``clientSecret``
            and ``clientSecretExpiresAt`` (epoch seconds) are read.
        sso_access_token: CreateToken response; ``accessToken`` and
            ``expiresIn`` (seconds from now) are read.
        cache_path: Destination file. Defaults to ``SSO_CACHE_PATH``.
    """
    if cache_path is None:
        cache_path = SSO_CACHE_PATH
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=sso_access_token["expiresIn"])
    secret_expires_at = datetime.fromtimestamp(sso_client["clientSecretExpiresAt"], tz=timezone.utc)
    cache_data = {
        "startUrl": SSO_START_URL,
        "region": SSO_REGION,
        "accessToken": sso_access_token["accessToken"],
        "expiresAt": expires_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "clientId": sso_client["clientId"],
        "clientSecret": sso_client["clientSecret"],
        "registrationExpiresAt": secret_expires_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
    }
    cache_dir = os.path.dirname(os.path.abspath(cache_path))
    os.makedirs(cache_dir, exist_ok=True)
    # Same directory as the target so os.replace is an atomic rename.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, prefix=".gen_ai_", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(cache_data, f, indent=4)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Best-effort cleanup of the temporary file; re-raise the real error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Register an AWS OIDC Client
def sso_register():
    """Register a public OIDC client named ``gen-ai-sso`` with the SSO OIDC service.

    Returns:
        The RegisterClient response; later steps read ``clientId``,
        ``clientSecret`` and ``clientSecretExpiresAt`` from it.

    Raises:
        Exception: wrapping any error raised while creating the boto3 client
            or calling RegisterClient.
    """
    try:
        return boto3.client("sso-oidc", region_name=SSO_REGION).register_client(
            clientName="gen-ai-sso",
            clientType="public",
        )
    except Exception as exc:
        raise Exception(f"Exception in sso_register: {exc}")
# Initializes the AWS SSO flow & prompt the User for authorization
def sso_start(client_registration):
    """Begin the OIDC device-authorization flow and prompt the user to approve it.

    The verification URL (and the user code, when the response carries one) is
    written to stderr before a browser tab is opened. The user can then finish
    the flow even when no browser can be launched (e.g. over SSH). stdout is
    left untouched because it carries the ``credential_process`` JSON output.

    Args:
        client_registration: RegisterClient response; ``clientId`` and
            ``clientSecret`` are read.

    Returns:
        The StartDeviceAuthorization response; the caller reads
        ``deviceCode``, ``interval`` and ``expiresIn`` from it.

    Raises:
        Exception: wrapping any error from the API call or the browser launch.
    """
    try:
        oidc_client = boto3.client("sso-oidc", region_name=SSO_REGION)
        sso_device_auth_resp = oidc_client.start_device_authorization(
            clientId=client_registration["clientId"],
            clientSecret=client_registration["clientSecret"],
            startUrl=SSO_START_URL,
        )
        verification_url = sso_device_auth_resp["verificationUriComplete"]
        user_code = sso_device_auth_resp.get("userCode")
        message = f"To authorize AWS SSO access, open: {verification_url}"
        if user_code:
            message += f"\nand confirm the code: {user_code}"
        # print(file=None) would fall back to stdout and corrupt the JSON output.
        if sys.stderr is not None:
            print(message, file=sys.stderr)
        webbrowser.open(verification_url, new=2)
        return sso_device_auth_resp
    except Exception as e:
        raise Exception(f"Exception in sso_start: {e}") from e
# Create AWS Access Token once User authorizes the flow
def sso_token(client_registration, device_authorization_response):
    """Poll CreateToken until the user approves the device authorization.

    Polling follows the OAuth device-code grant (RFC 8628):
    - wait ``interval`` seconds between attempts, or 5 if the response omits it;
    - on SlowDownException, lengthen the interval by 5 seconds for this and all
      later attempts.

    The overall deadline uses a monotonic clock, so wall-clock adjustments
    cannot shorten or extend it.

    Args:
        client_registration: RegisterClient response; ``clientId`` and
            ``clientSecret`` are read.
        device_authorization_response: StartDeviceAuthorization response;
            ``deviceCode``, ``interval`` and ``expiresIn`` are read.

    Returns:
        The CreateToken response.

    Raises:
        Exception: when the device code expires, polling times out, or any
            other API error occurs (the original error is chained).
    """
    oidc_client = boto3.client("sso-oidc", region_name=SSO_REGION)
    device_code = device_authorization_response["deviceCode"]
    # RFC 8628 §3.2: clients use 5 seconds when no interval is provided.
    interval = device_authorization_response.get("interval") or 5
    expires_in = device_authorization_response["expiresIn"]
    deadline = time.monotonic() + expires_in
    while time.monotonic() < deadline:
        try:
            return oidc_client.create_token(
                clientId=client_registration["clientId"],
                clientSecret=client_registration["clientSecret"],
                deviceCode=device_code,
                grantType="urn:ietf:params:oauth:grant-type:device_code"
            )
        except oidc_client.exceptions.AuthorizationPendingException:
            pass
        except oidc_client.exceptions.SlowDownException:
            # RFC 8628 §3.5: the increase applies to all subsequent requests.
            interval += 5
        except oidc_client.exceptions.ExpiredTokenException as e:
            raise Exception("The device code has expired. Please restart the process.") from e
        except Exception as e:
            raise Exception(f"An unexpected error occurred during token retrieval: {e}") from e
        time.sleep(interval)
    raise Exception("Authorization timed out. Please restart the process.")
# Convert our AWS Session into AWS Access/Secret/Token keys
def sso_credentials(access_token):
    """Exchange an SSO access token for temporary credentials of the SSO role.

    Calls GetRoleCredentials for ``SSO_ROLE_NAME`` in ``SSO_ACCOUNT_ID``.

    Args:
        access_token: SSO access token string.

    Returns:
        The ``roleCredentials`` dict from the response (it carries
        ``accessKeyId``, ``secretAccessKey`` and ``sessionToken``).

    Raises:
        Exception: wrapping any error from client creation or the API call.
    """
    try:
        sso_client = boto3.client("sso", region_name=SSO_REGION)
        credentials = sso_client.get_role_credentials(
            roleName=SSO_ROLE_NAME,
            accountId=SSO_ACCOUNT_ID,
            accessToken=access_token
        )
        return credentials["roleCredentials"]
    except Exception as e:
        # f-string so the underlying error text actually appears in the message.
        raise Exception(f"Exception in sso_credentials: {e}") from e
# Perform the end-to-end AWS SSO flow
def _sso_login_locked():
    """Run the SSO device flow; the caller must already hold the file lock.

    The cache is checked once more first, because another invocation may have
    completed the flow while this one was waiting for the lock.
    """
    token = get_sso_cached_token()
    if token:
        return token
    registration = sso_register()
    device_auth = sso_start(registration)
    token_response = sso_token(registration, device_auth)
    store_sso_cached_token(registration, token_response)
    return token_response["accessToken"]


def sso_do():
    """Return a usable SSO access token, running the device login flow if needed.

    Every ``AWS_*`` environment variable is removed from this process first,
    to avoid interference and profile recursion (this script being called
    again through ``credential_process``). An unexpired cached token is
    returned immediately. Otherwise the login runs while holding the file lock.

    Returns:
        The SSO access token string.

    Raises:
        Exception: if the file lock cannot be acquired; errors raised by the
            individual flow steps propagate unchanged.
    """
    inherited = [name for name in os.environ if name.startswith('AWS_')]
    for name in inherited:
        os.environ.pop(name)

    token = get_sso_cached_token()
    if token:
        return token

    # Serialize the interactive flow: applications often call AWS APIs
    # concurrently, which can start several copies of this script at once.
    try:
        with take_flock():
            return _sso_login_locked()
    except portalocker.LockException:
        raise Exception(f"Exception in sso_do: cannot acquire lock on {SSO_LOCK_PATH}")
# Role-hop using our AWS Credentials into the targeted AWS Role
def assume_role(sso_creds):
    """Assume ``GEN_AI_ROLE`` using the SSO role credentials.

    Args:
        sso_creds: Dict with ``accessKeyId``, ``secretAccessKey`` and
            ``sessionToken`` (the SSO ``roleCredentials``).

    Returns:
        A dict in the ``credential_process`` output shape: ``Version`` 1, the
        assumed role's ``AccessKeyId``/``SecretAccessKey``/``SessionToken``,
        and ``Expiration`` as an ISO 8601 string.

    Raises:
        Exception: wrapping any error from session creation or AssumeRole.
    """
    try:
        session = boto3.Session(
            aws_access_key_id=sso_creds["accessKeyId"],
            aws_secret_access_key=sso_creds["secretAccessKey"],
            aws_session_token=sso_creds["sessionToken"],
            region_name=GEN_AI_REGION
        )
        sts_client = session.client("sts")
        assumed_role_object = sts_client.assume_role(
            RoleArn=GEN_AI_ROLE,
            RoleSessionName="AssumeRoleSession"
        )
        credentials = assumed_role_object["Credentials"]
        return {
            "Version": 1,
            "AccessKeyId": credentials["AccessKeyId"],
            "SecretAccessKey": credentials["SecretAccessKey"],
            "SessionToken": credentials["SessionToken"],
            "Expiration": credentials["Expiration"].isoformat()
        }
    except Exception as e:
        # Name the actual function, matching the "Exception in <func>" convention
        # used by the other helpers in this script.
        raise Exception(f"Exception in assume_role: {e}") from e
def main():
    """Print ``credential_process`` JSON for ``GEN_AI_ROLE`` on stdout."""
    token = sso_do()
    base_creds = sso_credentials(token)
    document = assume_role(base_creds)
    print(json.dumps(document))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment